From 8277267e1fc817b431596009b704ae7ebd75d3ae Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 24 Jul 2023 14:24:21 +0800 Subject: [PATCH 1/2] [jvm-packages] set device to cuda when tree method is "gpu_hist" --- .../ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 5fc16ec0937b..572a426c330a 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -17,11 +17,9 @@ package ml.dmlc.xgboost4j.scala.spark import java.io.File - import scala.collection.mutable import scala.util.Random import scala.collection.JavaConverters._ - import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager @@ -30,7 +28,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.commons.io.FileUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.fs.FileSystem - import org.apache.spark.rdd.RDD import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.sql.SparkSession @@ -180,10 +177,12 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s " as 'hist', 'approx', 'gpu_hist', and 'auto'") treeMethod = Some(overridedParams("tree_method").asInstanceOf[String]) } - val device: Option[String] = overridedParams.get("device") match { - case None => None - case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev) - } + + // back-compatible with "gpu_hist" + val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) { + Some("cuda") + } else overridedParams.get("device").map(_.toString) + if (overridedParams.contains("train_test_ratio")) { logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + From 253083d34a21073a2558d865a3ce4fd0f71d4dff Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 24 Jul 2023 16:16:49 +0800 Subject: [PATCH 2/2] format --- .../src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 572a426c330a..f514eaa68b20 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -17,9 +17,11 @@ package ml.dmlc.xgboost4j.scala.spark import java.io.File + import scala.collection.mutable import scala.util.Random import scala.collection.JavaConverters._ + import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager @@ -28,6 +30,7 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.commons.io.FileUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.fs.FileSystem + import org.apache.spark.rdd.RDD import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.sql.SparkSession