From 9246da8970924d5cbfa5233bb738847950f376f8 Mon Sep 17 00:00:00 2001 From: "kinfai.kan" Date: Tue, 19 Feb 2019 00:05:58 -0800 Subject: [PATCH] modify OpXGBoostClassificationModel to generate two raw prediction values for binary classification --- .../classification/OpXGBoostClassifier.scala | 3 ++- .../classification/OpClassifierModelTest.scala | 12 +++++++++--- .../classification/OpXGBoostClassifierTest.scala | 16 ++++++++-------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifier.scala b/core/src/main/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifier.scala index bdfc3d5d41..30f9801fb4 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifier.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifier.scala @@ -366,9 +366,10 @@ class OpXGBoostClassificationModel val data = removeMissingValues(Iterator(features.value.asXGB), missing) val dm = new DMatrix(dataIter = data) val rawPred = booster.predict(dm, outPutMargin = true, treeLimit = treeLimit)(0).map(_.toDouble) + val rawPrediction = if (model.numClasses == 2) Array(-rawPred(0), rawPred(0)) else rawPred val prob = booster.predict(dm, outPutMargin = false, treeLimit = treeLimit)(0).map(_.toDouble) val probability = if (model.numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob val prediction = probability2predictionMirror(Vectors.dense(probability)).asInstanceOf[Double] - Prediction(prediction = prediction, rawPrediction = rawPred, probability = probability) + Prediction(prediction = prediction, rawPrediction = rawPrediction, probability = probability) } } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpClassifierModelTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpClassifierModelTest.scala index c6f44b2cfe..ab8168cd6c 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpClassifierModelTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpClassifierModelTest.scala @@ -134,10 +134,15 @@ class OpClassifierModelTest extends FlatSpec with TestSparkContext with OpXGBoos .setLabelCol(labelF.name) val spk = cl.fit(rawDF) val op = toOP(spk, spk.uid).setInput(labelF, featureV) - compareOutputs(spk.transform(rawDF), op.transform(rawDF)) + compareOutputs(spk.transform(rawDF), op.transform(rawDF), false) } - def compareOutputs(df1: DataFrame, df2: DataFrame)(implicit arrayEquality: Equality[Array[Double]]): Unit = { + def compareOutputs + ( + df1: DataFrame, + df2: DataFrame, + fullRawPred: Boolean = true + )(implicit arrayEquality: Equality[Array[Double]]): Unit = { def keysStartsWith(name: String, value: Map[String, Double]): Array[Double] = { val names = value.keys.filter(_.startsWith(name)).toArray.sorted names.map(value) @@ -148,7 +153,8 @@ class OpClassifierModelTest extends FlatSpec with TestSparkContext with OpXGBoos val map = r2.getAs[Map[String, Double]](2) r1.getAs[Double](4) shouldEqual map(Prediction.Keys.PredictionName) r1.getAs[Vector](3).toArray shouldEqual keysStartsWith(Prediction.Keys.ProbabilityName, map) - r1.getAs[Vector](2).toArray shouldEqual keysStartsWith(Prediction.Keys.RawPredictionName, map) + if (fullRawPred) r1.getAs[Vector](2).toArray shouldEqual keysStartsWith(Prediction.Keys.RawPredictionName, map) + else r1.getAs[Vector](2).toArray shouldEqual keysStartsWith(Prediction.Keys.RawPredictionName, map).tail } } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifierTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifierTest.scala index 77760810db..e7d5c54a1e 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifierTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/classification/OpXGBoostClassifierTest.scala @@ -64,14 +64,14 @@ class OpXGBoostClassifierTest extends OpEstimatorSpec[Prediction, OpPredictorWra estimator.setSilent(1) val expectedResult = Seq( - Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), - Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)), - Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)), - Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), - Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), - Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)), - Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), - Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)) + Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), + Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)), + Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)), + Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), + Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), + Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)), + Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)), + Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)) ) it should "allow the user to set the desired spark parameters" in {