restore use_gpu.
trivialfis committed Jul 17, 2023
1 parent 1b5394c commit 9b57a90
Showing 4 changed files with 64 additions and 20 deletions.
10 changes: 7 additions & 3 deletions python-package/xgboost/sklearn.py
@@ -1004,13 +1004,17 @@ def fit(
Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
.. deprecated:: 1.6.0
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.6.0
Use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead.
Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
instead.
verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage.
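
As a quick illustration of the deprecation notes above, a minimal sketch (toy NumPy data, shapes chosen arbitrarily) that passes `eval_metric` and `early_stopping_rounds` to the estimator constructor instead of `fit()`:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(128, 4), np.random.randint(0, 2, 128)

# Configure the metric and early stopping on the estimator itself ...
clf = xgb.XGBClassifier(n_estimators=50, eval_metric="logloss", early_stopping_rounds=5)
# ... and keep fit() down to the data and the evaluation set.
clf.fit(X[:100], y[:100], eval_set=[(X[100:], y[100:])])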
32 changes: 16 additions & 16 deletions python-package/xgboost/spark/core.py
@@ -340,8 +340,18 @@ def _validate_params(self) -> None:
f"It cannot be less than 1 [Default is 1]"
)

tree_method = self.getOrDefault(self.getParam("tree_method"))
if (
self.getOrDefault(self.use_gpu) or use_cuda(self.getOrDefault(self.device))
) and not _can_use_qdm(tree_method):
raise ValueError(
f"The `{tree_method}` tree method is not supported on GPU."
)

if self.getOrDefault(self.features_cols):
if not use_cuda(self.getOrDefault(self.device)):
if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault(
self.use_gpu
):
raise ValueError(
"features_col param with list value requires `device=cuda`."
)
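
A minimal standalone sketch of the validation restored above, with `use_cuda` and `_can_use_qdm` replaced by simplified stand-ins for illustration (the real helpers live in the xgboost.spark package): the GPU path is taken when either `device` selects CUDA or the legacy `use_gpu` flag is set, and tree methods that cannot use QuantileDMatrix are rejected on that path.

def _use_cuda(device):          # simplified stand-in for xgboost.spark's use_cuda()
    return (device or "").lower() in ("cuda", "gpu")

def _can_use_qdm(tree_method):  # simplified stand-in; "hist" supports QuantileDMatrix
    return tree_method in ("hist", "auto", None)

def validate(device=None, use_gpu=False, tree_method=None):
    if (_use_cuda(device) or use_gpu) and not _can_use_qdm(tree_method):
        raise ValueError(f"The `{tree_method}` tree method is not supported on GPU.")

validate(device="cuda", tree_method="hist")       # accepted
try:
    validate(use_gpu=True, tree_method="exact")   # rejected on the GPU path
except ValueError as err:
    print(err)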
@@ -396,7 +406,7 @@ def _validate_params(self) -> None:
"`pyspark.ml.linalg.Vector` type."
)

if use_cuda(self.getOrDefault(self.device)):
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
gpu_per_task = (
_get_spark_session()
.sparkContext.getConf()
@@ -553,6 +563,7 @@ def __init__(self) -> None:
self._setDefault(
num_workers=1,
device="cpu",
use_gpu=False,
force_repartition=False,
repartition_random_shuffle=False,
feature_names=None,
@@ -874,20 +885,9 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
dmatrix_kwargs,
) = self._get_xgb_parameters(dataset)

run_on_gpu = use_cuda(self.getOrDefault(self.device))
tree_method = self.getParam("tree_method")
# Validation before submitting function to worker.
if (
run_on_gpu
and self.getOrDefault(tree_method)
and self.getOrDefault(tree_method) != "hist"
):
raise ValueError(
f"The `{self.getOrDefault(tree_method)}` tree method is"
" not supported"
" on GPU."
)

run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
self.use_gpu
)
is_local = _is_local(_get_spark_session().sparkContext)

num_workers = self.getOrDefault(self.num_workers)
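
The corresponding change in `_fit` is that the tree-method check has moved into `_validate_params`, so `_fit` only decides whether to run on GPU. A sketch of that predicate, using the same stand-in assumption for `use_cuda` as before:

def resolve_run_on_gpu(device, use_gpu):
    # Mirrors run_on_gpu above: either the new `device` parameter or the restored
    # legacy `use_gpu` flag selects the GPU code path.
    return (device or "").lower() in ("cuda", "gpu") or bool(use_gpu)

assert resolve_run_on_gpu("cuda", False)       # new spelling
assert resolve_run_on_gpu("cpu", True)         # legacy flag still honoured
assert not resolve_run_on_gpu("cpu", False)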
29 changes: 28 additions & 1 deletion python-package/xgboost/spark/estimator.py
@@ -3,7 +3,7 @@
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=unused-argument, too-many-locals


import warnings
from typing import Any, List, Optional, Type, Union

import numpy as np
@@ -134,6 +134,10 @@ class SparkXGBRegressor(_SparkXGBEstimator):
num_workers:
How many XGBoost workers to be used to train.
Each XGBoost worker corresponds to one spark task.
use_gpu:
.. deprecated:: 2.0.0
Use `device` instead.
device:
Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
force_repartition:
@@ -194,6 +198,7 @@ def __init__(
weight_col: Optional[str] = None,
base_margin_col: Optional[str] = None,
num_workers: int = 1,
use_gpu: Optional[bool] = None,
device: Optional[str] = None,
force_repartition: bool = False,
repartition_random_shuffle: bool = False,
@@ -202,6 +207,10 @@
) -> None:
super().__init__()
input_kwargs = self._input_kwargs
if use_gpu:
warnings.warn(
"`use_gpu` is deprecated, use `device` instead", FutureWarning
)
self.setParams(**input_kwargs)

@classmethod
@@ -302,6 +311,10 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
num_workers:
How many XGBoost workers to be used to train.
Each XGBoost worker corresponds to one spark task.
use_gpu:
.. deprecated:: 2.0.0
Use `device` instead.
device:
Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
force_repartition:
@@ -362,6 +375,7 @@ def __init__(
weight_col: Optional[str] = None,
base_margin_col: Optional[str] = None,
num_workers: int = 1,
use_gpu: Optional[bool] = None,
device: Optional[str] = None,
force_repartition: bool = False,
repartition_random_shuffle: bool = False,
@@ -374,6 +388,10 @@
# binary or multinomial input dataset, and we need to remove the fixed default
# param value as well to avoid causing ambiguity.
input_kwargs = self._input_kwargs
if use_gpu:
warnings.warn(
"`use_gpu` is deprecated, use `device` instead", FutureWarning
)
self.setParams(**input_kwargs)
self._setDefault(objective=None)

@@ -473,6 +491,10 @@ class SparkXGBRanker(_SparkXGBEstimator):
num_workers:
How many XGBoost workers to be used to train.
Each XGBoost worker corresponds to one spark task.
use_gpu:
.. deprecated:: 2.0.0
Use `device` instead.
device:
Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
force_repartition:
@@ -539,6 +561,7 @@ def __init__(
base_margin_col: Optional[str] = None,
qid_col: Optional[str] = None,
num_workers: int = 1,
use_gpu: Optional[bool] = None,
device: Optional[str] = None,
force_repartition: bool = False,
repartition_random_shuffle: bool = False,
@@ -547,6 +570,10 @@
) -> None:
super().__init__()
input_kwargs = self._input_kwargs
if use_gpu:
warnings.warn(
"`use_gpu` is deprecated, use `device` instead", FutureWarning
)
self.setParams(**input_kwargs)

@classmethod
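
Constructor-level usage matching the docstrings above: the legacy `use_gpu=True` spelling is still accepted but emits a FutureWarning, while `device="cuda"` is the forward-looking replacement. A sketch assuming a working pyspark environment; the same applies to SparkXGBClassifier and SparkXGBRanker.

import warnings

from xgboost.spark import SparkXGBRegressor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    SparkXGBRegressor(num_workers=2, use_gpu=True)   # deprecated spelling
assert any(issubclass(w.category, FutureWarning) for w in caught)

SparkXGBRegressor(num_workers=2, device="cuda")      # preferred spelling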
13 changes: 13 additions & 0 deletions tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
@@ -197,6 +197,19 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature
f1 = evaluator.evaluate(pred_result_df)
assert f1 >= 0.97

clf = SparkXGBClassifier(
features_col=feature_names, use_gpu=True, num_workers=num_workers
)
grid = ParamGridBuilder().addGrid(clf.max_depth, [6, 8]).build()
evaluator = MulticlassClassificationEvaluator(metricName="f1")
cv = CrossValidator(
estimator=clf, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3
)
cvModel = cv.fit(train_df)
pred_result_df = cvModel.transform(test_df)
f1 = evaluator.evaluate(pred_result_df)
assert f1 >= 0.97


def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
from pyspark.ml.evaluation import RegressionEvaluator
