Skip to content

Commit

Permalink
modify OpXGBoostClassificationModel to generate two raw prediction va…
Browse files Browse the repository at this point in the history
…lues for binary classification (#229)
  • Loading branch information
kinfaikan authored and tovbinm committed Feb 19, 2019
1 parent a893747 commit bfc8312
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,10 @@ class OpXGBoostClassificationModel
val data = removeMissingValues(Iterator(features.value.asXGB), missing)
val dm = new DMatrix(dataIter = data)
val rawPred = booster.predict(dm, outPutMargin = true, treeLimit = treeLimit)(0).map(_.toDouble)
val rawPrediction = if (model.numClasses == 2) Array(-rawPred(0), rawPred(0)) else rawPred
val prob = booster.predict(dm, outPutMargin = false, treeLimit = treeLimit)(0).map(_.toDouble)
val probability = if (model.numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
val prediction = probability2predictionMirror(Vectors.dense(probability)).asInstanceOf[Double]
Prediction(prediction = prediction, rawPrediction = rawPred, probability = probability)
Prediction(prediction = prediction, rawPrediction = rawPrediction, probability = probability)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,15 @@ class OpClassifierModelTest extends FlatSpec with TestSparkContext with OpXGBoos
.setLabelCol(labelF.name)
val spk = cl.fit(rawDF)
val op = toOP(spk, spk.uid).setInput(labelF, featureV)
compareOutputs(spk.transform(rawDF), op.transform(rawDF))
compareOutputs(spk.transform(rawDF), op.transform(rawDF), false)
}

def compareOutputs(df1: DataFrame, df2: DataFrame)(implicit arrayEquality: Equality[Array[Double]]): Unit = {
def compareOutputs
(
df1: DataFrame,
df2: DataFrame,
fullRawPred: Boolean = true
)(implicit arrayEquality: Equality[Array[Double]]): Unit = {
def keysStartsWith(name: String, value: Map[String, Double]): Array[Double] = {
val names = value.keys.filter(_.startsWith(name)).toArray.sorted
names.map(value)
Expand All @@ -148,7 +153,8 @@ class OpClassifierModelTest extends FlatSpec with TestSparkContext with OpXGBoos
val map = r2.getAs[Map[String, Double]](2)
r1.getAs[Double](4) shouldEqual map(Prediction.Keys.PredictionName)
r1.getAs[Vector](3).toArray shouldEqual keysStartsWith(Prediction.Keys.ProbabilityName, map)
r1.getAs[Vector](2).toArray shouldEqual keysStartsWith(Prediction.Keys.RawPredictionName, map)
if (fullRawPred) r1.getAs[Vector](2).toArray shouldEqual keysStartsWith(Prediction.Keys.RawPredictionName, map)
else r1.getAs[Vector](2).toArray shouldEqual keysStartsWith(Prediction.Keys.RawPredictionName, map).tail
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ class OpXGBoostClassifierTest extends OpEstimatorSpec[Prediction, OpPredictorWra
estimator.setSilent(1)

val expectedResult = Seq(
Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
Prediction(1.0, Array(0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(0.0, Array(0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284))
Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284))
)

it should "allow the user to set the desired spark parameters" in {
Expand Down

0 comments on commit bfc8312

Please sign in to comment.