From 642b4412310adc8bc078700aae5eb237302cb8dd Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 30 Aug 2018 13:21:27 -0400
Subject: [PATCH] [jvm-packages] Fix #3489: Spark repartitionForData can
 potentially shuffle all data and lose ordering required for ranking
 objectives

---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  | 326 ++++++++++++------
 .../scala/spark/XGBoostClassifier.scala       |   3 +-
 .../scala/spark/XGBoostRegressor.scala        |   2 +-
 .../scala/spark/XGBoostClassifierSuite.scala  |   2 +-
 .../scala/spark/XGBoostGeneralSuite.scala     |  50 ++-
 5 files changed, 274 insertions(+), 109 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index fa1dccc53b67..d31810098b07 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 import java.nio.file.Files
 
-import scala.collection.mutable
+import scala.collection.{AbstractIterator, mutable}
 import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@@ -53,6 +53,17 @@ object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
 }
 
+/**
+ * Training data group in an RDD partition.
+ * @param groupId The group id.
+ * @param points Array of XGBLabeledPoint within the same group.
+ * @param isEdgeGroup whether it is the first or last group in an RDD partition.
+ */
+private[spark] case class XGBLabeledPointGroup(
+    groupId: Int,
+    points: Array[XGBLabeledPoint],
+    isEdgeGroup: Boolean)
+
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
@@ -74,78 +85,62 @@ object XGBoost extends Serializable {
     }
   }
 
-  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
-    val builder = new mutable.ArrayBuilder.ofFloat()
-    var nTotal = 0
-    var nUndefined = 0
-    while (baseMargins.hasNext) {
-      nTotal += 1
-      val baseMargin = baseMargins.next()
-      if (baseMargin.isNaN) {
-        nUndefined += 1 // don't waste space for all-NaNs.
-      } else {
-        builder += baseMargin
+  private def removeMissingValuesWithGroup(
+      xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
+      missing: Float): Iterator[Array[XGBLabeledPoint]] = {
+    if (!missing.isNaN) {
+      xgbLabelPointGroups.map {
+        labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
       }
+    } else {
+      xgbLabelPointGroups
     }
-    if (nUndefined == nTotal) {
-      None
-    } else if (nUndefined == 0) {
-      Some(builder.result())
+  }
+
+  private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
+    val taskId = TaskContext.getPartitionId().toString
+    if (useExternalMemory) {
+      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
+      Some(dir.toAbsolutePath.toString)
     } else {
-      throw new IllegalArgumentException(
-        s"Encountered a partition with $nUndefined NaN base margin values. " +
" + - s"If you want to specify base margin, ensure all values are non-NaN.") + None } } - private[spark] def buildDistributedBoosters( - data: RDD[XGBLabeledPoint], + private def buildDistributedBooster( + watches: Watches, params: Map[String, Any], rabitEnv: java.util.Map[String, String], round: Int, obj: ObjectiveTrait, eval: EvalTrait, - useExternalMemory: Boolean, - missing: Float, - prevBooster: Booster - ): RDD[(Booster, Map[String, Array[Float]])] = { + prevBooster: Booster) + : Iterator[(Booster, Map[String, Array[Float]])] = { - val partitionedBaseMargin = data.map(_.baseMargin) // to workaround the empty partitions in training dataset, // this might not be the best efficient implementation, see // (https://github.com/dmlc/xgboost/issues/1277) - data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) => - if (labeledPoints.isEmpty) { - throw new XGBoostError( - s"detected an empty partition in the training data, partition ID:" + - s" ${TaskContext.getPartitionId()}") - } - val taskId = TaskContext.getPartitionId().toString - val cacheDirName = if (useExternalMemory) { - val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId") - Some(dir.toAbsolutePath.toString) - } else { - None - } - rabitEnv.put("DMLC_TASK_ID", taskId) - Rabit.init(rabitEnv) - val watches = Watches(params, - removeMissingValues(labeledPoints, missing), - fromBaseMarginsToArray(baseMargins), cacheDirName) - - try { - val numEarlyStoppingRounds = params.get("num_early_stopping_rounds") - .map(_.toString.toInt).getOrElse(0) - val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) - val booster = SXGBoost.train(watches.train, params, round, - watches.toMap, metrics, obj, eval, - earlyStoppingRound = numEarlyStoppingRounds, prevBooster) - Iterator(booster -> watches.toMap.keys.zip(metrics).toMap) - } finally { - Rabit.shutdown() - watches.delete() - } - }.cache() + if (watches.train.rowNum == 0) { + throw new XGBoostError( + s"detected an empty partition in the training data, partition ID:" + + s" ${TaskContext.getPartitionId()}") + } + val taskId = TaskContext.getPartitionId().toString + rabitEnv.put("DMLC_TASK_ID", taskId) + Rabit.init(rabitEnv) + + try { + val numEarlyStoppingRounds = params.get("num_early_stopping_rounds") + .map(_.toString.toInt).getOrElse(0) + val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) + val booster = SXGBoost.train(watches.train, params, round, + watches.toMap, metrics, obj, eval, + earlyStoppingRound = numEarlyStoppingRounds, prevBooster) + Iterator(booster -> watches.toMap.keys.zip(metrics).toMap) + } finally { + Rabit.shutdown() + watches.delete() + } } private def overrideParamsAccordingToTaskCPUs( @@ -219,7 +214,8 @@ object XGBoost extends Serializable { obj: ObjectiveTrait = null, eval: EvalTrait = null, useExternalMemory: Boolean = false, - missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = { + missing: Float = Float.NaN, + hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = { validateSparkSslConf(trainingData.context) if (params.contains("tree_method")) { require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" + @@ -244,7 +240,6 @@ object XGBoost extends Serializable { " an instance of Long.") } val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params) - val partitionedData = repartitionForTraining(trainingData, nWorkers) val sc = trainingData.sparkContext val checkpointManager = new 
@@ -258,9 +253,29 @@ object XGBoost extends Serializable {
     try {
       val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
       val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-      val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
-        tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
-        prevBooster)
+      val rabitEnv = tracker.getWorkerEnvs
+      val boostersAndMetrics = hasGroup match {
+        case true => {
+          val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+          partitionedData.mapPartitions(labeledPointGroups => {
+            val watches = Watches.buildWatchesWithGroup(params,
+              removeMissingValuesWithGroup(labeledPointGroups, missing),
+              getCacheDirName(useExternalMemory))
+            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+              obj, eval, prevBooster)
+          }).cache()
+        }
+        case false => {
+          val partitionedData = repartitionForTraining(trainingData, nWorkers)
+          partitionedData.mapPartitions(labeledPoints => {
+            val watches = Watches.buildWatches(params,
+              removeMissingValues(labeledPoints, missing),
+              getCacheDirName(useExternalMemory))
+            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+              obj, eval, prevBooster)
+          }).cache()
+        }
+      }
       val sparkJobThread = new Thread() {
         override def run() {
           // force the job
@@ -278,13 +293,12 @@ object XGBoost extends Serializable {
           checkpointManager.updateCheckpoint(prevBooster)
         }
         (booster, metrics)
-    } finally {
-      tracker.stop()
-    }
+      } finally {
+        tracker.stop()
+      }
     }.last
   }
 
-
   private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
     if (trainingData.getNumPartitions != nWorkers) {
       logger.info(s"repartitioning training set to $nWorkers partitions")
@@ -294,6 +308,31 @@ object XGBoost extends Serializable {
     }
   }
 
+  private[spark] def repartitionForTrainingGroup(
+      trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
+    val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
+      // LabeledPointGroupIterator returns one XGBLabeledPointGroup per group
+      new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
+
+    // edge groups with partition id.
+    val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
+      new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
+      group => (TaskContext.getPartitionId(), group))
+
+    // group chunks from different partitions together by group id in XGBLabeledPoint.
+    // use groupBy instead of aggregateByKey since all groups within a partition
+    // have unique group ids.
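+    // Each edge group is keyed by its group id, so the chunks of a group that
+    // was split across a partition boundary land in the same bucket; sorting
+    // the chunks by their original partition id below restores row order
+    // before they are concatenated back into one complete group.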
+    val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
+      groups => {
+        val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
+        // sort by partition id and merge the list of Array[XGBLabeledPoint] into one array
+        it.toArray.sortBy(_._1).map(_._2.points).flatten
+      })
+
+    val allGroups = normalGroups.union(stitchedGroups)
+    logger.info(s"repartitioning training group set to $nWorkers partitions")
+    allGroups.repartition(nWorkers)
+  }
+
   private def postTrackerReturnProcessing(
       trackerReturnVal: Int,
       distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
@@ -321,9 +360,9 @@ object XGBoost extends Serializable {
 }
 
 private class Watches private(
-  val train: DMatrix,
-  val test: DMatrix,
-  private val cacheDirName: Option[String]) {
+    val train: DMatrix,
+    val test: DMatrix,
+    private val cacheDirName: Option[String]) {
 
   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -342,59 +381,152 @@ private class Watches private(
 
 private object Watches {
 
-  def buildGroups(groups: Seq[Int]): Seq[Int] = {
-    val output = mutable.ArrayBuffer.empty[Int]
-    var count = 1
-    var lastGroup = groups.head
-    for (group <- groups.tail) {
-      if (group != lastGroup) {
-        lastGroup = group
-        output += count
-        count = 1
+  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
+    val builder = new mutable.ArrayBuilder.ofFloat()
+    var nTotal = 0
+    var nUndefined = 0
+    while (baseMargins.hasNext) {
+      nTotal += 1
+      val baseMargin = baseMargins.next()
+      if (baseMargin.isNaN) {
+        nUndefined += 1 // don't waste space for all-NaNs.
       } else {
-        count += 1
+        builder += baseMargin
       }
     }
-    output += count
-    output
+    if (nUndefined == nTotal) {
+      None
+    } else if (nUndefined == 0) {
+      Some(builder.result())
+    } else {
+      throw new IllegalArgumentException(
+        s"Encountered a partition with $nUndefined NaN base margin values. " +
" + + s"If you want to specify base margin, ensure all values are non-NaN.") + } } - def apply( + def buildWatches( params: Map[String, Any], labeledPoints: Iterator[XGBLabeledPoint], - baseMarginsOpt: Option[Array[Float]], cacheDirName: Option[String]): Watches = { val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0) val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime()) val r = new Random(seed) val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint] + val trainBaseMargins = new mutable.ArrayBuilder.ofFloat + val testBaseMargins = new mutable.ArrayBuilder.ofFloat val trainPoints = labeledPoints.filter { labeledPoint => val accepted = r.nextDouble() <= trainTestRatio if (!accepted) { testPoints += labeledPoint + testBaseMargins += labeledPoint.baseMargin + } else { + trainBaseMargins += labeledPoint.baseMargin } + accepted + } + val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull) + val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull) + + val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator) + val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator) + if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get) + if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get) + + new Watches(trainMatrix, testMatrix, cacheDirName) + } + def buildWatchesWithGroup( + params: Map[String, Any], + labeledPointGroups: Iterator[Array[XGBLabeledPoint]], + cacheDirName: Option[String]): Watches = { + val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0) + val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime()) + val r = new Random(seed) + val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint] + val trainBaseMargins = new mutable.ArrayBuilder.ofFloat + val testBaseMargins = new mutable.ArrayBuilder.ofFloat + val trainGroups = new mutable.ArrayBuilder.ofInt + val testGroups = new mutable.ArrayBuilder.ofInt + + val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup => + val accepted = r.nextDouble() <= trainTestRatio + if (!accepted) { + labeledPointGroup.foreach(labeledPoint => { + testPoints += labeledPoint + testBaseMargins += labeledPoint.baseMargin + }) + testGroups += labeledPointGroup.length + } else { + labeledPointGroup.foreach(trainBaseMargins += _.baseMargin) + trainGroups += labeledPointGroup.length + } accepted } - val (trainIter1, trainIter2) = trainPoints.duplicate - val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull) - val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray - trainMatrix.setGroup(trainGroups) + val trainPoints = trainLabelPointGroups.flatMap(_.iterator) + val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull) + trainMatrix.setGroup(trainGroups.result()) - val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull) + val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull) if (trainTestRatio < 1.0) { - val testGroups = buildGroups(testPoints.map(_.group)).toArray - testMatrix.setGroup(testGroups) + testMatrix.setGroup(testGroups.result()) } - r.setSeed(seed) - for (baseMargins <- baseMarginsOpt) { - val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio) - trainMatrix.setBaseMargin(trainMargin) - 
-    }
+    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
 
     new Watches(trainMatrix, testMatrix, cacheDirName)
   }
 }
+
+/**
+ * Within each RDD partition, group the XGBLabeledPoints by group id.
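+ * For example, a partition holding points with group ids (5, 5, 6, 7, 7)
+ * yields three groups: (5, 5), (6) and (7, 7).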
+ * The first and the last groups may not have all of their items, since a group
+ * can be split across a partition boundary.
+ * LabeledPointGroupIterator organizes the data in a tuple format:
+ *     (isFirstGroup || isLastGroup, Array[XGBLabeledPoint]).
+ * The edge groups across partitions can be stitched together later.
+ * @param base collection of XGBLabeledPoint
+ */
+private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
+  extends AbstractIterator[XGBLabeledPointGroup] {
+
+  private var firstPointOfNextGroup: XGBLabeledPoint = null
+  private var isNewGroup = false
+
+  override def hasNext: Boolean = {
+    base.hasNext || isNewGroup
+  }
+
+  override def next(): XGBLabeledPointGroup = {
+    val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
+    var isFirstGroup = true
+    if (firstPointOfNextGroup != null) {
+      builder += firstPointOfNextGroup
+      isFirstGroup = false
+    }
+
+    isNewGroup = false
+    while (!isNewGroup && base.hasNext) {
+      val point = base.next()
+      val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
+      firstPointOfNextGroup = point
+      if (point.group == groupId) {
+        // add to current group
+        builder += point
+      } else {
+        // start a new group
+        isNewGroup = true
+      }
+    }
+
+    val isLastGroup = !isNewGroup
+    val result = builder.result()
+    val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
+
+    group
+  }
+}
+
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index c8ac28eb4057..869c1fe9c30e 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -196,7 +196,7 @@ class XGBoostClassifier (
     // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
       $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing))
+      $(missing), hasGroup = false)
     val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
@@ -517,3 +517,4 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
     }
   }
 }
+
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 277d5566974c..5ce659bb0774 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -191,7 +191,7 @@ class XGBoostRegressor (
     // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
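+    // hasGroup routes training through the group-aware repartitioning path
+    // in XGBoost.trainDistributed when group information is present.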
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
       $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing))
+      $(missing), hasGroup = group != lit(-1))
     val model = new XGBoostRegressionModel(uid, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 9c92f8810dc9..8b7bdb44821e 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -173,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val training2 = training1.withColumn("margin", functions.rand())
     val test = buildDataFrame(Classification.test)
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic", "test_train_split" -> "0.5",
+      "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
       "num_round" -> 5, "num_workers" -> numWorkers)
 
     val xgb = new XGBoostClassifier(paramMap)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 05e9ae75735d..243ab4fc24b1 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque
 import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -71,18 +72,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(collectedAllReduceResults.poll().sameElements(maxVec))
   }
 
-  test("build RDD containing boosters with the specified worker number") {
+  test("distributed training with the specified worker number") {
     val trainingRDD = sc.parallelize(Classification.train)
-    val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
-    val boosterRDD = XGBoost.buildDistributedBoosters(
-      partitionedRDD,
+    val (booster, metrics) = XGBoost.trainDistributed(
+      trainingRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
         "objective" -> "binary:logistic").toMap,
-      new java.util.HashMap[String, String](),
-      round = 5, eval = null, obj = null, useExternalMemory = true,
-      missing = Float.NaN, prevBooster = null)
-    val boosterCount = boosterRDD.count()
-    assert(boosterCount === 2)
+      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+      hasGroup = false, missing = Float.NaN)
+
+    assert(booster != null)
   }
 
   test("training with external memory cache") {
@@ -235,4 +234,37 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(error(prevModel._booster) > error(nextModel._booster))
     assert(error(nextModel._booster) < 0.1)
   }
+
+  test("repartitionForTrainingGroup with group data") {
+    // test different splits to cover the corner cases.
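+    // Varying the input partition count moves the partition boundaries, which
+    // changes which groups become edge groups and must be stitched back together.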
+    for (split <- 1 to 20) {
+      val trainingRDD = sc.parallelize(Ranking.train, split)
+      val trainingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
+      val trainingGroups: Array[Array[XGBLabeledPoint]] = trainingGroupsRDD.collect()
+      // check the order of the groups with group id.
+      // Ranking.train has 20 groups
+      assert(trainingGroups.length == 20)
+
+      // compare all points
+      val allPoints = trainingGroups.sortBy(_(0).group).flatten
+      assert(allPoints.length == Ranking.train.size)
+      for (i <- 0 until Ranking.train.size) {
+        assert(allPoints(i).group == Ranking.train(i).group)
+        assert(allPoints(i).label == Ranking.train(i).label)
+        assert(allPoints(i).values.sameElements(Ranking.train(i).values))
+      }
+    }
+  }
+
+  test("distributed training with group data") {
+    val trainingRDD = sc.parallelize(Ranking.train, 2)
+    val (booster, metrics) = XGBoost.trainDistributed(
+      trainingRDD,
+      List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+        "objective" -> "binary:logistic").toMap,
+      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+      hasGroup = true, missing = Float.NaN)
+
+    assert(booster != null)
+  }
 }
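
--
Reviewer note (not part of the patch): the grouping invariant above is easier
to see in isolation. Below is a minimal, self-contained Scala sketch of the
same idea; the names GroupIteratorSketch, Point and groupRuns are illustrative
stand-ins, not the patch's API. It splits one partition's points into runs of
consecutive identical group ids and flags the first and last run as edge
groups, which is the invariant LabeledPointGroupIterator maintains before
repartitionForTrainingGroup stitches edge groups across partitions.

object GroupIteratorSketch {
  // Stand-in for XGBLabeledPoint: (group id, payload).
  type Point = (Int, String)

  // Split a partition's points into runs of consecutive identical group ids.
  // The first and last runs are flagged as edge groups, since a group may
  // straddle a partition boundary and need stitching with its neighbour.
  def groupRuns(points: Iterator[Point]): Iterator[(Boolean, Array[Point])] = {
    val buffered = points.buffered
    var isFirst = true
    new Iterator[(Boolean, Array[Point])] {
      def hasNext: Boolean = buffered.hasNext
      def next(): (Boolean, Array[Point]) = {
        val gid = buffered.head._1
        val run = Array.newBuilder[Point]
        while (buffered.hasNext && buffered.head._1 == gid) run += buffered.next()
        val isEdge = isFirst || !buffered.hasNext // first or last run of the partition
        isFirst = false
        (isEdge, run.result())
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // One "partition": group 7 is cut off at the end, so it is an edge group.
    val partition = Iterator((5, "a"), (5, "b"), (6, "c"), (7, "d"))
    groupRuns(partition).foreach { case (edge, run) =>
      println(s"edge=$edge group=${run.head._1} size=${run.length}")
    }
    // Prints: edge=true group=5 size=2, edge=false group=6 size=1,
    //         edge=true group=7 size=1
  }
}

The real iterator avoids BufferedIterator by carrying the first point of the
next group across next() calls, but the edge-group flagging is the same.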