diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index a2ea4444378f..53a06aa2fb24 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -133,7 +133,6 @@ abstract class XGBoostModel(protected var _booster: Booster) } } } finally { - Rabit.shutdown() dMatrix.delete() } } else { @@ -151,7 +150,7 @@ abstract class XGBoostModel(protected var _booster: Booster) * @param testSet test set represented as RDD * @param missingValue the specified value to represent the missing value */ - def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { + def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => val sampleArray = testSamples.toList @@ -169,7 +168,7 @@ abstract class XGBoostModel(protected var _booster: Booster) } val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) try { - Iterator(broadcastBooster.value.predict(dMatrix)) + broadcastBooster.value.predict(dMatrix).iterator } finally { Rabit.shutdown() dMatrix.delete() @@ -188,7 +187,7 @@ abstract class XGBoostModel(protected var _booster: Booster) def predict( testSet: RDD[MLVector], useExternalCache: Boolean = false, - outputMargin: Boolean = false): RDD[Array[Array[Float]]] = { + outputMargin: Boolean = false): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) val appName = testSet.context.appName testSet.mapPartitions { testSamples => @@ -205,7 +204,7 @@ abstract class XGBoostModel(protected var _booster: Booster) } val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName)) try { - Iterator(broadcastBooster.value.predict(dMatrix)) + broadcastBooster.value.predict(dMatrix).iterator } finally { Rabit.shutdown() dMatrix.delete() diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index d4007401bf1d..c5724270f66e 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -252,7 +252,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { "objective" -> "binary:logistic") val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) val predRDD = xgBoostModel.predict(testRDD) - val predResult1 = predRDD.collect()(0) + val predResult1 = predRDD.collect() assert(testRDD.count() === predResult1.length) import DataUtils._ val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator)) @@ -358,7 +358,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1) val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect()(0) + val predResult1: Array[Array[Float]] = predRDD.collect() assert(testRDD.count() === predResult1.length) val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData) @@ -386,7 +386,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2) val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect()(0) + val predResult1: Array[Array[Float]] = predRDD.collect() assert(testRDD.count() === predResult1.length) } @@ -403,7 +403,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val trainMargin = { XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2) .predict(trainRDD.map(_.features), outputMargin = true) - .flatMap { _.flatten.iterator } + .flatMap { _.iterator } } val xgBoostModel = XGBoost.trainWithRDD(