From 642b4412310adc8bc078700aae5eb237302cb8dd Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 30 Aug 2018 13:21:27 -0400
Subject: [PATCH] [jvm-packages] Fix #3489: Spark repartitionForData can
potentially shuffle all data and lose ordering required for ranking
objectives
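
Ranking objectives require that all records of a query group end up in the same
partition and keep their relative order. A plain row-level repartition can split
a group across partitions and reorder its records. This patch groups the records
of each partition by group id (LabeledPointGroupIterator), stitches the partial
edge groups that straddle partition boundaries back together by group id, and
only then repartitions the resulting whole groups to the requested number of
workers. Watches.buildWatchesWithGroup builds the per-group DMatrix information
from these whole groups when hasGroup is set.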
---
.../dmlc/xgboost4j/scala/spark/XGBoost.scala | 326 ++++++++++++------
.../scala/spark/XGBoostClassifier.scala | 3 +-
.../scala/spark/XGBoostRegressor.scala | 2 +-
.../scala/spark/XGBoostClassifierSuite.scala | 2 +-
.../scala/spark/XGBoostGeneralSuite.scala | 50 ++-
5 files changed, 274 insertions(+), 109 deletions(-)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index fa1dccc53b67..d31810098b07 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
-import scala.collection.mutable
+import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@@ -53,6 +53,17 @@ object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
}
+/**
+ * Training data group in an RDD partition.
+ * @param groupId The group id.
+ * @param points Array of XGBLabeledPoint within the same group.
+ * @param isEdgeGroup Whether it is the first or the last group in an RDD partition.
+ */
+private[spark] case class XGBLabeledPointGroup(
+ groupId: Int,
+ points: Array[XGBLabeledPoint],
+ isEdgeGroup: Boolean)
+
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
@@ -74,78 +85,62 @@ object XGBoost extends Serializable {
}
}
- private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
- val builder = new mutable.ArrayBuilder.ofFloat()
- var nTotal = 0
- var nUndefined = 0
- while (baseMargins.hasNext) {
- nTotal += 1
- val baseMargin = baseMargins.next()
- if (baseMargin.isNaN) {
- nUndefined += 1 // don't waste space for all-NaNs.
- } else {
- builder += baseMargin
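+ // Applies removeMissingValues to the points of each group while keeping the
+ // Array-per-group structure (and hence the group boundaries) intact; the input
+ // is returned unchanged when missing is NaN.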
+ private def removeMissingValuesWithGroup(
+ xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
+ missing: Float): Iterator[Array[XGBLabeledPoint]] = {
+ if (!missing.isNaN) {
+ xgbLabelPointGroups.map {
+ labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
}
+ } else {
+ xgbLabelPointGroups
}
- if (nUndefined == nTotal) {
- None
- } else if (nUndefined == 0) {
- Some(builder.result())
+ }
+
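+ // Returns a per-task temporary directory for the external memory cache files,
+ // or None when external memory is not used.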
+ private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
+ val taskId = TaskContext.getPartitionId().toString
+ if (useExternalMemory) {
+ val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
+ Some(dir.toAbsolutePath.toString)
} else {
- throw new IllegalArgumentException(
- s"Encountered a partition with $nUndefined NaN base margin values. " +
- s"If you want to specify base margin, ensure all values are non-NaN.")
+ None
}
}
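+ // Runs a single Rabit worker: checks for an empty training partition, initializes
+ // Rabit with this task's id, trains one Booster on the partition-local Watches and
+ // returns it together with the per-watch evaluation metrics.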
- private[spark] def buildDistributedBoosters(
- data: RDD[XGBLabeledPoint],
+ private def buildDistributedBooster(
+ watches: Watches,
params: Map[String, Any],
rabitEnv: java.util.Map[String, String],
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
- useExternalMemory: Boolean,
- missing: Float,
- prevBooster: Booster
- ): RDD[(Booster, Map[String, Array[Float]])] = {
+ prevBooster: Booster)
+ : Iterator[(Booster, Map[String, Array[Float]])] = {
- val partitionedBaseMargin = data.map(_.baseMargin)
// to workaround the empty partitions in training dataset,
// this might not be the best efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
- data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
- if (labeledPoints.isEmpty) {
- throw new XGBoostError(
- s"detected an empty partition in the training data, partition ID:" +
- s" ${TaskContext.getPartitionId()}")
- }
- val taskId = TaskContext.getPartitionId().toString
- val cacheDirName = if (useExternalMemory) {
- val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
- Some(dir.toAbsolutePath.toString)
- } else {
- None
- }
- rabitEnv.put("DMLC_TASK_ID", taskId)
- Rabit.init(rabitEnv)
- val watches = Watches(params,
- removeMissingValues(labeledPoints, missing),
- fromBaseMarginsToArray(baseMargins), cacheDirName)
-
- try {
- val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
- .map(_.toString.toInt).getOrElse(0)
- val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
- val booster = SXGBoost.train(watches.train, params, round,
- watches.toMap, metrics, obj, eval,
- earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
- Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
- } finally {
- Rabit.shutdown()
- watches.delete()
- }
- }.cache()
+ if (watches.train.rowNum == 0) {
+ throw new XGBoostError(
+ s"detected an empty partition in the training data, partition ID:" +
+ s" ${TaskContext.getPartitionId()}")
+ }
+ val taskId = TaskContext.getPartitionId().toString
+ rabitEnv.put("DMLC_TASK_ID", taskId)
+ Rabit.init(rabitEnv)
+
+ try {
+ val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
+ .map(_.toString.toInt).getOrElse(0)
+ val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
+ val booster = SXGBoost.train(watches.train, params, round,
+ watches.toMap, metrics, obj, eval,
+ earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+ Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+ } finally {
+ Rabit.shutdown()
+ watches.delete()
+ }
}
private def overrideParamsAccordingToTaskCPUs(
@@ -219,7 +214,8 @@ object XGBoost extends Serializable {
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
- missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
+ missing: Float = Float.NaN,
+ hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
validateSparkSslConf(trainingData.context)
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
@@ -244,7 +240,6 @@ object XGBoost extends Serializable {
" an instance of Long.")
}
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
- val partitionedData = repartitionForTraining(trainingData, nWorkers)
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -258,9 +253,29 @@ object XGBoost extends Serializable {
try {
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
- val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
- tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
- prevBooster)
+ val rabitEnv = tracker.getWorkerEnvs
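+ // For ranking objectives (hasGroup), repartition by whole query groups so that a
+ // group is never split across partitions; otherwise keep the original row-level
+ // repartitioning.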
+ val boostersAndMetrics = hasGroup match {
+ case true => {
+ val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+ partitionedData.mapPartitions(labeledPointGroups => {
+ val watches = Watches.buildWatchesWithGroup(overriddenParams,
+ removeMissingValuesWithGroup(labeledPointGroups, missing),
+ getCacheDirName(useExternalMemory))
+ buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
+ obj, eval, prevBooster)
+ }).cache()
+ }
+ case false => {
+ val partitionedData = repartitionForTraining(trainingData, nWorkers)
+ partitionedData.mapPartitions(labeledPoints => {
+ val watches = Watches.buildWatches(overriddenParams,
+ removeMissingValues(labeledPoints, missing),
+ getCacheDirName(useExternalMemory))
+ buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
+ obj, eval, prevBooster)
+ }).cache()
+ }
+ }
val sparkJobThread = new Thread() {
override def run() {
// force the job
@@ -278,13 +293,12 @@ object XGBoost extends Serializable {
checkpointManager.updateCheckpoint(prevBooster)
}
(booster, metrics)
- } finally {
- tracker.stop()
- }
+ } finally {
+ tracker.stop()
+ }
}.last
}
-
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
if (trainingData.getNumPartitions != nWorkers) {
logger.info(s"repartitioning training set to $nWorkers partitions")
@@ -294,6 +308,31 @@ object XGBoost extends Serializable {
}
}
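+ // Builds an RDD of complete query groups: points are grouped by group id within each
+ // partition, the edge groups cut in half by a partition boundary are stitched back
+ // together, and the resulting whole groups are repartitioned into nWorkers partitions.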
+ private[spark] def repartitionForTrainingGroup(
+ trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
+ val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
+ // LabeledPointGroupIterator emits one XGBLabeledPointGroup per query group
+ new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
+
+ // edge groups with partition id.
+ val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
+ new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
+ group => (TaskContext.getPartitionId(), group))
+
+ // group chunks from different partitions together by group id in XGBLabeledPoint.
+ // use groupBy instead of aggregateByKey since all groups within a partition have unique group ids.
+ val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
+ groups => {
+ val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
+ // sort by partition id and merge the arrays of XGBLabeledPoint into one array
+ it.toArray.sortBy(_._1).map(_._2.points).flatten
+ })
+
+ val allGroups = normalGroups.union(stitchedGroups)
+ logger.info(s"repartitioning training group set to $nWorkers partitions")
+ allGroups.repartition(nWorkers)
+ }
+
private def postTrackerReturnProcessing(
trackerReturnVal: Int,
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
@@ -321,9 +360,9 @@ object XGBoost extends Serializable {
}
private class Watches private(
- val train: DMatrix,
- val test: DMatrix,
- private val cacheDirName: Option[String]) {
+ val train: DMatrix,
+ val test: DMatrix,
+ private val cacheDirName: Option[String]) {
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
.filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -342,59 +381,152 @@ private class Watches private(
private object Watches {
- def buildGroups(groups: Seq[Int]): Seq[Int] = {
- val output = mutable.ArrayBuffer.empty[Int]
- var count = 1
- var lastGroup = groups.head
- for (group <- groups.tail) {
- if (group != lastGroup) {
- lastGroup = group
- output += count
- count = 1
+ private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
+ val builder = new mutable.ArrayBuilder.ofFloat()
+ var nTotal = 0
+ var nUndefined = 0
+ while (baseMargins.hasNext) {
+ nTotal += 1
+ val baseMargin = baseMargins.next()
+ if (baseMargin.isNaN) {
+ nUndefined += 1 // don't waste space for all-NaNs.
} else {
- count += 1
+ builder += baseMargin
}
}
- output += count
- output
+ if (nUndefined == nTotal) {
+ None
+ } else if (nUndefined == 0) {
+ Some(builder.result())
+ } else {
+ throw new IllegalArgumentException(
+ s"Encountered a partition with $nUndefined NaN base margin values. " +
+ s"If you want to specify base margin, ensure all values are non-NaN.")
+ }
}
- def apply(
+ def buildWatches(
params: Map[String, Any],
labeledPoints: Iterator[XGBLabeledPoint],
- baseMarginsOpt: Option[Array[Float]],
cacheDirName: Option[String]): Watches = {
val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
val r = new Random(seed)
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
+ val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
+ val testBaseMargins = new mutable.ArrayBuilder.ofFloat
val trainPoints = labeledPoints.filter { labeledPoint =>
val accepted = r.nextDouble() <= trainTestRatio
if (!accepted) {
testPoints += labeledPoint
+ testBaseMargins += labeledPoint.baseMargin
+ } else {
+ trainBaseMargins += labeledPoint.baseMargin
}
+ accepted
+ }
+ val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+ val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
+
+ val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+ val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+ if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+ if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
+
+ new Watches(trainMatrix, testMatrix, cacheDirName)
+ }
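+ // Same as buildWatches, but the train/test split is made on whole groups and the size
+ // of every group is recorded so that setGroup() can be called on the resulting matrices.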
+ def buildWatchesWithGroup(
+ params: Map[String, Any],
+ labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
+ cacheDirName: Option[String]): Watches = {
+ val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
+ val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
+ val r = new Random(seed)
+ val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
+ val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
+ val testBaseMargins = new mutable.ArrayBuilder.ofFloat
+ val trainGroups = new mutable.ArrayBuilder.ofInt
+ val testGroups = new mutable.ArrayBuilder.ofInt
+
+ val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
+ val accepted = r.nextDouble() <= trainTestRatio
+ if (!accepted) {
+ labeledPointGroup.foreach(labeledPoint => {
+ testPoints += labeledPoint
+ testBaseMargins += labeledPoint.baseMargin
+ })
+ testGroups += labeledPointGroup.length
+ } else {
+ labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
+ trainGroups += labeledPointGroup.length
+ }
accepted
}
- val (trainIter1, trainIter2) = trainPoints.duplicate
- val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
- val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
- trainMatrix.setGroup(trainGroups)
+ val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
+ val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+ trainMatrix.setGroup(trainGroups.result())
- val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
+ val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
if (trainTestRatio < 1.0) {
- val testGroups = buildGroups(testPoints.map(_.group)).toArray
- testMatrix.setGroup(testGroups)
+ testMatrix.setGroup(testGroups.result())
}
- r.setSeed(seed)
- for (baseMargins <- baseMarginsOpt) {
- val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
- trainMatrix.setBaseMargin(trainMargin)
- testMatrix.setBaseMargin(testMargin)
- }
+ val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+ val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+ if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+ if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
new Watches(trainMatrix, testMatrix, cacheDirName)
}
}
+
+/**
+ * Within each RDD partition, groups the XGBLabeledPoints by their group id.
+ * The first and the last group of a partition may be incomplete, because the records of a
+ * query group can straddle a partition boundary.
+ * LabeledPointGroupIterator emits one XGBLabeledPointGroup per group and flags these edge
+ * groups (isEdgeGroup = isFirstGroup || isLastGroup) so that they can be stitched back
+ * together across partitions later.
+ * @param base collection of XGBLabeledPoint
+ */
+private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
+ extends AbstractIterator[XGBLabeledPointGroup] {
+
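+ // firstPointOfNextGroup holds the most recently read point; when next() stops at a
+ // group boundary this is the first point of the following group and is emitted at
+ // the start of the next call to next().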
+ private var firstPointOfNextGroup: XGBLabeledPoint = null
+ private var isNewGroup = true
+
+ override def hasNext: Boolean = base.hasNext || isNewGroup
+
+ override def next(): XGBLabeledPointGroup = {
+ val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
+ var isFirstGroup = true
+ if (firstPointOfNextGroup != null) {
+ builder += firstPointOfNextGroup
+ isFirstGroup = false
+ }
+
+ isNewGroup = false
+ while (!isNewGroup && base.hasNext) {
+ val point = base.next()
+ val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
+ firstPointOfNextGroup = point
+ if (point.group == groupId) {
+ // add to current group
+ builder += point
+ } else {
+ // start a new group
+ isNewGroup = true
+ }
+ }
+
+ val isLastGroup = !isNewGroup
+ val result = builder.result()
+ val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
+
+ group
+ }
+}
+
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index c8ac28eb4057..869c1fe9c30e 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -196,7 +196,7 @@ class XGBoostClassifier (
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
- $(missing))
+ $(missing), hasGroup = false)
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
@@ -517,3 +517,4 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
}
}
}
+
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 277d5566974c..5ce659bb0774 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -191,7 +191,7 @@ class XGBoostRegressor (
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
- $(missing))
+ $(missing), hasGroup = group != lit(-1))
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 9c92f8810dc9..8b7bdb44821e 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -173,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
val training2 = training1.withColumn("margin", functions.rand())
val test = buildDataFrame(Classification.test)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "test_train_split" -> "0.5",
+ "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
"num_round" -> 5, "num_workers" -> numWorkers)
val xgb = new XGBoostClassifier(paramMap)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 05e9ae75735d..243ab4fc24b1 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -71,18 +72,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
- test("build RDD containing boosters with the specified worker number") {
+ test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
- val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
- val boosterRDD = XGBoost.buildDistributedBoosters(
- partitionedRDD,
+ val (booster, metrics) = XGBoost.trainDistributed(
+ trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
- new java.util.HashMap[String, String](),
- round = 5, eval = null, obj = null, useExternalMemory = true,
- missing = Float.NaN, prevBooster = null)
- val boosterCount = boosterRDD.count()
- assert(boosterCount === 2)
+ round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+ hasGroup = false, missing = Float.NaN)
+
+ assert(booster != null)
}
test("training with external memory cache") {
@@ -235,4 +234,37 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
}
+
+ test("repartitionForTrainingGroup with group data") {
+ // test different splits to cover the corner cases.
+ for (split <- 1 to 20) {
+ val trainingRDD = sc.parallelize(Ranking.train, split)
+ val trainingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
+ val trainingGroups: Array[Array[XGBLabeledPoint]] = trainingGroupsRDD.collect()
+ // check the order of the groups by group id.
+ // Ranking.train has 20 groups
+ assert(trainingGroups.length == 20)
+
+ // compare all points
+ val allPoints = trainingGroups.sortBy(_(0).group).flatten
+ assert(allPoints.length == Ranking.train.size)
+ for (i <- 0 to Ranking.train.size - 1) {
+ assert(allPoints(i).group == Ranking.train(i).group)
+ assert(allPoints(i).label == Ranking.train(i).label)
+ assert(allPoints(i).values.sameElements(Ranking.train(i).values))
+ }
+ }
+ }
+
+ test("distributed training with group data") {
+ val trainingRDD = sc.parallelize(Ranking.train, 2)
+ val (booster, metrics) = XGBoost.trainDistributed(
+ trainingRDD,
+ List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "binary:logistic").toMap,
+ round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+ hasGroup = true, missing = Float.NaN)
+
+ assert(booster != null)
+ }
}