Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XGBoost classification & regression models + Spark 2.3.2 #44

Merged
merged 78 commits into from
Oct 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
eb2da76
Initial implementation of XGBoost classifier & regressor moddels + up…
tovbinm Aug 8, 2018
e46a152
fix property name
tovbinm Aug 8, 2018
533aa3a
Minor updates
tovbinm Aug 8, 2018
4877aa3
added maven repo
tovbinm Aug 8, 2018
3a023bd
move repo to build.gradle
tovbinm Aug 8, 2018
bb547eb
Merge branch 'master' into mt/xgboost
tovbinm Aug 8, 2018
0801e1b
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 9, 2018
26c1d89
quite logging in tests
tovbinm Aug 9, 2018
b355d4c
update some tests
tovbinm Aug 9, 2018
9f7bf85
debug stuff
tovbinm Aug 9, 2018
625dd2e
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 9, 2018
75bec3f
remove line
tovbinm Aug 9, 2018
e28e13a
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 9, 2018
565868e
Fix GeneralizedLinearRegression
tovbinm Aug 9, 2018
2b85332
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 9, 2018
de7969d
Make xgboost work
tovbinm Aug 10, 2018
ea71e94
cleanup
tovbinm Aug 10, 2018
42a95cf
update test
tovbinm Aug 10, 2018
07391ad
Added test
tovbinm Aug 10, 2018
fafa9a0
Added test
tovbinm Aug 10, 2018
be245a2
make stalastyle happy
tovbinm Aug 10, 2018
c41d1b6
Merge branch 'master' into mt/xgboost
tovbinm Aug 10, 2018
7dbfd68
fix tests
tovbinm Aug 10, 2018
a3e6334
cleanup
tovbinm Aug 10, 2018
c40f264
Merge branch 'mt/xgboost' of github.com:salesforce/TransmogrifAI into…
tovbinm Aug 10, 2018
48b996c
Merge branch 'master' into mt/xgboost
tovbinm Aug 11, 2018
897fc5e
Fixed expected midpoint in DecisionTreeNumericBucketizer to reflect n…
Jauntbox Aug 13, 2018
8668a2f
Merge branch 'mt/xgboost' of github.com:salesforce/TransmogrifAI into…
Jauntbox Aug 13, 2018
418d5b4
Merge branch 'master' into mt/xgboost
tovbinm Aug 14, 2018
7dfc5a8
Merge branch 'master' into mt/xgboost
tovbinm Aug 15, 2018
58a1c7b
Merge branch 'master' into mt/xgboost
tovbinm Aug 15, 2018
56a91d9
Update workflow runner test to use tempDir
tovbinm Aug 15, 2018
7d6e139
Merge branch 'mt/xgboost' of github.com:salesforce/TransmogrifAI into…
tovbinm Aug 15, 2018
46911f4
flatmap futures
tovbinm Aug 15, 2018
6773310
Merge branch 'master' into mt/xgboost
tovbinm Aug 16, 2018
9073591
update with official 0.80 release
tovbinm Aug 17, 2018
c242b84
update double opt equality
tovbinm Aug 17, 2018
a7c9d29
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 17, 2018
aca5012
Merge branch 'master' into mt/xgboost
tovbinm Aug 17, 2018
2eba329
reuse the internal xgboost method
tovbinm Aug 17, 2018
1284475
Merge branch 'mt/xgboost' of github.com:salesforce/TransmogrifAI into…
tovbinm Aug 17, 2018
4ee9763
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 18, 2018
d008caf
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 20, 2018
55cd176
organize imports
tovbinm Aug 20, 2018
3f75c6f
Merge branch 'master' into mt/xgboost
tovbinm Aug 21, 2018
b553bbe
added xgboost contributions to model insights
tovbinm Aug 21, 2018
8e5504a
Merge branch 'master' into mt/xgboost
tovbinm Aug 23, 2018
6c1aa1b
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Aug 24, 2018
d424a88
Merge branch 'master' into mt/xgboost
tovbinm Aug 24, 2018
b20426a
minor fixes
tovbinm Aug 25, 2018
39bce39
update test
tovbinm Aug 25, 2018
a9dd209
update spec name
tovbinm Aug 25, 2018
29571d7
update tests
tovbinm Aug 25, 2018
9e658c2
final fixes
tovbinm Aug 25, 2018
60565a5
added enums
tovbinm Aug 25, 2018
7e4857c
replace spark.sparkContext with sc
tovbinm Aug 25, 2018
1a529ae
Merge branch 'master' into mt/xgboost
tovbinm Aug 25, 2018
c7ada57
Merge branch 'master' into mt/xgboost
tovbinm Aug 25, 2018
a18dc8b
Merge branch 'master' into mt/xgboost
tovbinm Aug 28, 2018
6e575fc
Merge branch 'master' into mt/xgboost
tovbinm Aug 28, 2018
1bf34d5
Merge branch 'master' into mt/xgboost
tovbinm Aug 29, 2018
77de545
Merge branch 'master' into mt/xgboost
tovbinm Aug 30, 2018
36d5837
Merge branch 'master' into mt/xgboost
tovbinm Aug 31, 2018
4165bdd
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Sep 1, 2018
91d52a8
move version
tovbinm Sep 1, 2018
e7607e5
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Sep 1, 2018
f108111
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Sep 13, 2018
0271ab6
Merge branch 'master' into mt/xgboost
tovbinm Sep 14, 2018
76a6d6f
make it compile
tovbinm Sep 14, 2018
031a37d
cleanup
tovbinm Sep 14, 2018
d38b934
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Sep 17, 2018
5ec1a32
spark 2.3.2
tovbinm Sep 28, 2018
c3b2502
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Sep 28, 2018
b97f510
Merge branch 'master' into mt/xgboost
tovbinm Oct 1, 2018
2c75617
Merge branch 'master' into mt/xgboost
tovbinm Oct 11, 2018
9d49d21
added docs
tovbinm Oct 12, 2018
55e7f27
remove enums for now
tovbinm Oct 12, 2018
5a98621
Addressed comments + docs
tovbinm Oct 13, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ configure(allProjs) {
scalaCheckVersion = '1.14.0'
junitVersion = '4.11'
avroVersion = '1.7.7'
sparkVersion = '2.2.1'
sparkVersion = '2.3.2'
sparkAvroVersion = '4.0.0'
scalaGraphVersion = '1.12.5'
scalafmtVersion = '1.5.1'
Expand All @@ -69,27 +69,29 @@ configure(allProjs) {
json4sVersion = '3.2.11' // matches Spark dependency version
jodaTimeVersion = '2.9.4'
jodaConvertVersion = '1.8.1'
algebirdVersion = '0.12.3'
algebirdVersion = '0.13.4'
jacksonVersion = '2.7.3'
luceneVersion = '7.3.0'
enumeratumVersion = '1.4.12'
scoptVersion = '3.5.0'
googleLibPhoneNumberVersion = '8.8.5'
googleGeoCoderVersion = '2.82'
googleCarrierVersion = '1.72'
chillAvroVersion = '0.8.0'
chillVersion = '0.8.4'
reflectionsVersion = '0.9.11'
collectionsVersion = '3.2.2'
optimaizeLangDetectorVersion = '0.0.1'
tikaVersion = '1.16'
sparkTestingBaseVersion = '2.2.0_0.8.0'
sparkTestingBaseVersion = '2.3.1_0.10.0'
sourceCodeVersion = '0.1.3'
pegdownVersion = '1.4.2'
commonsValidatorVersion = '1.6'
commonsIOVersion = '2.6'
scoveragePluginVersion = '1.3.1'
hadrianVersion = '0.8.5'
aardpfarkVersion = '0.1.0-SNAPSHOT'
xgboostVersion = '0.80'
akkaSlf4jVersion = '2.3.11'

mainClassName = 'com.salesforce.Main'
}
Expand Down
5 changes: 5 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@ dependencies {

// Scopt
compile "com.github.scopt:scopt_$scalaVersion:$scoptVersion"

// XGBoost
compile "ml.dmlc:xgboost4j-spark:$xgboostVersion"
// Akka slfj4 logging (version matches XGBoost dependency)
testCompile "com.typesafe.akka:akka-slf4j_$scalaVersion:$akkaSlf4jVersion"
}
59 changes: 31 additions & 28 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import com.salesforce.op.utils.spark.RichMetadata._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.salesforce.op.utils.table.Alignment._
import com.salesforce.op.utils.table.Table
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostRegressionModel}
import org.apache.spark.ml.classification._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.{Model, PipelineStage, Transformer}
Expand Down Expand Up @@ -631,43 +632,45 @@ case object ModelInsights {
}

private[op] def getModelContributions(model: Option[Model[_]]): Seq[Seq[Double]] = {
model.map {
case m: SparkWrapperParams[_] => m.getSparkMlStage() match { // TODO add additional models
case Some(m: LogisticRegressionModel) => m.coefficientMatrix.rowIter.toSeq.map(_.toArray.toSeq)
case Some(m: RandomForestClassificationModel) => Seq(m.featureImportances.toArray.toSeq)
case Some(m: NaiveBayesModel) => m.theta.rowIter.toSeq.map(_.toArray.toSeq)
case Some(m: DecisionTreeClassificationModel) => Seq(m.featureImportances.toArray.toSeq)
case Some(m: GBTClassificationModel) => Seq(m.featureImportances.toArray.toSeq)
case Some(m: LinearSVCModel) => Seq(m.coefficients.toArray.toSeq)
case Some(m: LinearRegressionModel) => Seq(m.coefficients.toArray.toSeq)
case Some(m: DecisionTreeRegressionModel) => Seq(m.featureImportances.toArray.toSeq)
case Some(m: RandomForestRegressionModel) => Seq(m.featureImportances.toArray.toSeq)
case Some(m: GBTRegressionModel) => Seq(m.featureImportances.toArray.toSeq)
case Some(m: GeneralizedLinearRegressionModel) => Seq(m.coefficients.toArray.toSeq)
case _ => Seq.empty[Seq[Double]]
}
case _ => Seq.empty[Seq[Double]]
}.getOrElse(Seq.empty[Seq[Double]])
val stage = model.flatMap {
case m: SparkWrapperParams[_] => m.getSparkMlStage()
case _ => None
}
val contributions = stage.collect {
case m: LogisticRegressionModel => m.coefficientMatrix.rowIter.toSeq.map(_.toArray.toSeq)
case m: RandomForestClassificationModel => Seq(m.featureImportances.toArray.toSeq)
case m: NaiveBayesModel => m.theta.rowIter.toSeq.map(_.toArray.toSeq)
case m: DecisionTreeClassificationModel => Seq(m.featureImportances.toArray.toSeq)
case m: GBTClassificationModel => Seq(m.featureImportances.toArray.toSeq)
case m: LinearSVCModel => Seq(m.coefficients.toArray.toSeq)
case m: LinearRegressionModel => Seq(m.coefficients.toArray.toSeq)
case m: DecisionTreeRegressionModel => Seq(m.featureImportances.toArray.toSeq)
case m: RandomForestRegressionModel => Seq(m.featureImportances.toArray.toSeq)
case m: GBTRegressionModel => Seq(m.featureImportances.toArray.toSeq)
case m: GeneralizedLinearRegressionModel => Seq(m.coefficients.toArray.toSeq)
case m: XGBoostRegressionModel => Seq(m.nativeBooster.getFeatureScore().values.map(_.toDouble).toSeq)
case m: XGBoostClassificationModel => Seq(m.nativeBooster.getFeatureScore().values.map(_.toDouble).toSeq)
}
contributions.getOrElse(Seq.empty)
}

private def getModelInfo(model: Option[Model[_]]): Option[ModelSelectorSummary] = {
model match {
case Some(m: SelectedModel) => Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata()))
.toOption
case Some(m: SelectedModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case _ => None
}
}

private def getStageInfo(stages: Array[OPStage]): Map[String, Any] = {
def getParams(stage: PipelineStage): Map[String, String] =
stage.extractParamMap().toSeq
.collect{
case p if p.param.name == OpPipelineStageParamsNames.InputFeatures =>
p.param.name -> p.value.asInstanceOf[Array[TransientFeature]].map(_.toJsonString()).mkString(", ")
case p if p.param.name != OpPipelineStageParamsNames.OutputMetadata &&
p.param.name != OpPipelineStageParamsNames.InputSchema => p.param.name -> p.value.toString
}.toMap

def getParams(stage: PipelineStage): Map[String, String] = {
stage.extractParamMap().toSeq.collect {
case p if p.param.name == OpPipelineStageParamsNames.InputFeatures =>
p.param.name -> p.value.asInstanceOf[Array[TransientFeature]].map(_.toJsonString()).mkString(", ")
case p if p.param.name != OpPipelineStageParamsNames.OutputMetadata &&
p.param.name != OpPipelineStageParamsNames.InputSchema => p.param.name -> p.value.toString
}.toMap
}
stages.map { s =>
val params = s match {
case m: Model[_] => getParams(if (m.hasParent) m.parent else m) // try for parent estimator so can get params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ package com.salesforce.op.evaluators

import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.salesforce.op.UID
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._
import com.twitter.algebird.Tuple4Semigroup
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
Expand Down Expand Up @@ -84,6 +84,7 @@ private[op] class OpBinScoreEvaluator

// Finding stats per bin -> avg score, avg conv rate,
// total num of data points and overall brier score.
implicit val sg = new Tuple4Semigroup[Double, Double, Long, Double]()
val stats = scoreAndLabels.map {
case (score, label) =>
(getBinIndex(score, minScore, maxScore), (score, label, 1L, math.pow(score - label, 2)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.salesforce.op.UID
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._
import com.twitter.algebird.Tuple2Semigroup
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{DoubleArrayParam, IntArrayParam}
Expand Down Expand Up @@ -226,8 +227,8 @@ private[op] class OpMultiClassificationEvaluator
.map(_ -> (new Array[Long](nThresholds), new Array[Long](nThresholds)))
.toMap[Label, CorrIncorr]

val agg: MetricsMap =
data.treeAggregate[MetricsMap](zeroValue)(combOp = _ + _, seqOp = _ + computeMetrics(_))
implicit val sgTuple2 = new Tuple2Semigroup[Array[Long], Array[Long]]()
val agg: MetricsMap = data.treeAggregate[MetricsMap](zeroValue)(combOp = _ + _, seqOp = _ + computeMetrics(_))

val nRows = data.count()
ThresholdMetrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ package com.salesforce.op.filters
import com.salesforce.op.OpParams
import com.salesforce.op.features.types._
import com.salesforce.op.features.{OPFeature, TransientFeature}
import com.salesforce.op.filters.FeatureDistribution._
import com.salesforce.op.filters.Summary._
import com.salesforce.op.readers.{DataFrameFieldNames, Reader}
import com.salesforce.op.stages.impl.feature.TimePeriod
import com.salesforce.op.stages.impl.preparators.CorrelationType
import com.salesforce.op.utils.spark.RichRow._
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._
import com.twitter.algebird.Tuple2Semigroup
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -139,26 +142,27 @@ class RawFeatureFilter[T]
None
}
val predOut = allPredictors.map(TransientFeature(_))

(respOut, predOut)
}
val preparedFeatures: RDD[PreparedFeatures] =
data.rdd.map(PreparedFeatures(_, responses, predictors, timePeriod))
val preparedFeatures: RDD[PreparedFeatures] = data.rdd.map(PreparedFeatures(_, responses, predictors, timePeriod))

implicit val sgTuple2Maps = new Tuple2Semigroup[Map[FeatureKey, Summary], Map[FeatureKey, Summary]]()
// Have to use the training summaries do process scoring for comparison
val (responseSummaries, predictorSummaries): (Map[FeatureKey, Summary], Map[FeatureKey, Summary]) =
allFeatureInfo.map(info => info.responseSummaries -> info.predictorSummaries)
.getOrElse(preparedFeatures.map(_.summaries).reduce(_ + _))
val (responseSummariesArr, predictorSummariesArr): (Array[(FeatureKey, Summary)], Array[(FeatureKey, Summary)]) =
(responseSummaries.toArray, predictorSummaries.toArray)

implicit val sgTuple2Feats = new Tuple2Semigroup[Array[FeatureDistribution], Array[FeatureDistribution]]()
val (responseDistributions, predictorDistributions): (Array[FeatureDistribution], Array[FeatureDistribution]) =
preparedFeatures
.map(_.getFeatureDistributions(
responseSummaries = responseSummariesArr,
predictorSummaries = predictorSummariesArr,
bins = bins,
textBinsFormula = textBinsFormula
))
.reduce(_ + _) // NOTE: resolved semigroup is IndexedSeqSemigroup
)).reduce(_ + _)
val correlationInfo: Map[FeatureKey, Map[FeatureKey, Double]] =
allFeatureInfo.map(_.correlationInfo).getOrElse {
val responseKeys: Array[FeatureKey] = responseSummariesArr.map(_._1)
Expand Down
7 changes: 4 additions & 3 deletions core/src/main/scala/com/salesforce/op/filters/Summary.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ case object Summary {
val empty: Summary = Summary(Double.PositiveInfinity, Double.NegativeInfinity, 0.0, 0.0)

implicit val monoid: Monoid[Summary] = new Monoid[Summary] {
override def zero = empty
override def plus(l: Summary, r: Summary) = Summary(math.min(l.min, r.min), math.max(l.max, r.max),
l.sum + r.sum, l.count + r.count)
override def zero = Summary.empty
override def plus(l: Summary, r: Summary) = Summary(
math.min(l.min, r.min), math.max(l.max, r.max), l.sum + r.sum, l.count + r.count
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ class OpLinearSVCModel
ttov: TypeTag[Prediction#Value]
) extends OpPredictorWrapperModel[LinearSVCModel](uid = uid, operationName = operationName, sparkModel = sparkModel) {

@transient private lazy val predictRaw = reflectMethod(getSparkMlStage().get, "predictRaw")
@transient private lazy val predict = reflectMethod(getSparkMlStage().get, "predict")
@transient lazy private val predictRaw = reflectMethod(getSparkMlStage().get, "predictRaw")
@transient lazy private val predict = reflectMethod(getSparkMlStage().get, "predict")

/**
* Function used to convert input to output
*/
override def transformFn: (RealNN, OPVector) => Prediction = (label, features) => {
val raw = predictRaw.apply(features.value).asInstanceOf[Vector]
val pred = predict.apply(features.value).asInstanceOf[Double]
val raw = predictRaw(features.value).asInstanceOf[Vector]
val pred = predict(features.value).asInstanceOf[Double]

Prediction(rawPrediction = raw, prediction = pred)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ class OpLogisticRegression(uid: String = UID[OpLogisticRegression])
class OpLogisticRegressionModel
(
sparkModel: LogisticRegressionModel,
operationName: String = classOf[LogisticRegression].getSimpleName,
uid: String = UID[OpLogisticRegressionModel]
uid: String = UID[OpLogisticRegressionModel],
operationName: String = classOf[LogisticRegression].getSimpleName
)(
implicit tti1: TypeTag[RealNN],
tti2: TypeTag[OPVector],
Expand All @@ -210,4 +210,3 @@ class OpLogisticRegressionModel
@transient lazy val probability2predictionMirror =
reflectMethod(getSparkMlStage().get, "probability2prediction")
}

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ package com.salesforce.op.stages.impl.classification
import com.salesforce.op.UID
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictionModel, OpPredictorWrapper}
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpProbabilisticClassifierModel}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier, OpMultilayerPerceptronClassifierParams}
import org.apache.spark.ml.linalg.Vector
Expand Down Expand Up @@ -128,7 +128,6 @@ class OpMultilayerPerceptronClassifier(uid: String = UID[OpMultilayerPerceptronC
* @param uid uid to give stage
* @param operationName unique name of the operation this stage performs
*/
// TODO in next release of spark this will be a probabilistic classifier
class OpMultilayerPerceptronClassificationModel
(
sparkModel: MultilayerPerceptronClassificationModel,
Expand All @@ -139,9 +138,12 @@ class OpMultilayerPerceptronClassificationModel
tti2: TypeTag[OPVector],
tto: TypeTag[Prediction],
ttov: TypeTag[Prediction#Value]
) extends OpPredictionModel[MultilayerPerceptronClassificationModel](
) extends OpProbabilisticClassifierModel[MultilayerPerceptronClassificationModel](
sparkModel = sparkModel, uid = uid, operationName = operationName
) {
@transient lazy val predictMirror = reflectMethod(getSparkMlStage().get, "predict")
@transient lazy val predictRawMirror = reflectMethod(getSparkMlStage().get, "predictRaw")
@transient lazy val raw2probabilityMirror = reflectMethod(getSparkMlStage().get, "raw2probability")
@transient lazy val probability2predictionMirror =
reflectMethod(getSparkMlStage().get, "probability2prediction")
}

Loading