diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py index 692d9e3a0e..e19fe39da7 100755 --- a/python/cuml/dask/ensemble/randomforestclassifier.py +++ b/python/cuml/dask/ensemble/randomforestclassifier.py @@ -74,16 +74,11 @@ class RandomForestClassifier(BaseRandomForestModel, DelayedPredictionMixin, run different models concurrently in different streams by creating handles in several streams. If it is None, a new one is created. - split_criterion : The criterion used to split nodes. - 0 for GINI, 1 for ENTROPY, 4 for CRITERION_END. - 2 and 3 not valid for classification - (default = 0) - split_algo : 0 for HIST and 1 for GLOBAL_QUANTILE (default = 1) - the algorithm to determine how nodes are split in the tree. - split_criterion : The criterion used to split nodes. - 0 for GINI, 1 for ENTROPY, 4 for CRITERION_END. - 2 and 3 not valid for classification - (default = 0) + split_criterion : int or string (default = 0 ('gini')) + The criterion used to split nodes. + 0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY, + 2 or 'mse' for MSE + 2 or 'mse' not valid for classification bootstrap : boolean (default = True) Control bootstrapping. If set, each tree in the forest is built @@ -112,17 +107,6 @@ class RandomForestClassifier(BaseRandomForestModel, DelayedPredictionMixin, If float, then min_samples_split represents a fraction and ceil(min_samples_split * n_rows) is the minimum number of samples for each split. - quantile_per_tree : boolean (default = False) - Whether quantile is computed for individual RF trees. - Only relevant for GLOBAL_QUANTILE split_algo. - use_experimental_backend : boolean (default = True) - If set to true and the following conditions are also met, a new - experimental backend for decision tree training will be used. The - new backend is available only if `split_algo = 1` (GLOBAL_QUANTILE) - and `quantile_per_tree = False` (No per tree quantile computation). 
- The new backend is considered stable for classification tasks but - not yet for regression tasks. The RAPIDS team is continuing - optimization and evaluation of the new backend for regression tasks. n_streams : int (default = 4 ) Number of parallel streams used for forest building workers : optional, list of strings diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py index 3b21810fb4..4e28dda9f7 100755 --- a/python/cuml/dask/ensemble/randomforestregressor.py +++ b/python/cuml/dask/ensemble/randomforestregressor.py @@ -68,14 +68,11 @@ class RandomForestRegressor(BaseRandomForestModel, DelayedPredictionMixin, run different models concurrently in different streams by creating handles in several streams. If it is None, a new one is created. - split_algo : int (default = 1) - 0 for HIST, 1 for GLOBAL_QUANTILE - The type of algorithm to be used to create the trees. - split_criterion : int (default = 2) + split_criterion : int or string (default = 2 ('mse')) The criterion used to split nodes. - 0 for GINI, 1 for ENTROPY, - 2 for MSE, 3 for MAE and 4 for CRITERION_END. - 0 and 1 not valid for regression + 0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY, + 2 or 'mse' for MSE + only 2 or 'mse' valid for regression bootstrap : boolean (default = True) Control bootstrapping. If set, each tree in the forest is built @@ -118,17 +115,6 @@ class RandomForestRegressor(BaseRandomForestModel, DelayedPredictionMixin, for median of abs error : 'median_ae' for mean of abs error : 'mean_ae' for mean square error' : 'mse' - quantile_per_tree : boolean (default = False) - Whether quantile is computed for individual RF trees. - Only relevant for GLOBAL_QUANTILE split_algo. - use_experimental_backend : boolean (default = False) - If set to true and the following conditions are also met, a new - experimental backend for decision tree training will be used. 
The - new backend is available only if `split_algo = 1` (GLOBAL_QUANTILE) - and `quantile_per_tree = False` (No per tree quantile computation). - The new backend is considered stable for classification tasks but - not yet for regression tasks. The RAPIDS team is continuing - optimization and evaluation of the new backend for regression tasks. n_streams : int (default = 4 ) Number of parallel streams used for forest building workers : optional, list of strings diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index b04d5567f6..2226d5329e 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import ctypes import cupy as cp import math @@ -41,21 +40,24 @@ from cuml.common.array_descriptor import CumlArrayDescriptor class BaseRandomForestModel(Base): _param_names = ['n_estimators', 'max_depth', 'handle', 'max_features', 'n_bins', - 'split_algo', 'split_criterion', 'min_samples_leaf', + 'split_criterion', 'min_samples_leaf', 'min_samples_split', 'min_impurity_decrease', 'bootstrap', 'verbose', 'max_samples', 'max_leaves', - 'accuracy_metric', 'use_experimental_backend', + 'accuracy_metric', 'max_batch_size', 'n_streams', 'dtype', 'output_type', 'min_weight_fraction_leaf', 'n_jobs', 'max_leaf_nodes', 'min_impurity_split', 'oob_score', 'random_state', 'warm_start', 'class_weight', 'criterion'] - criterion_dict = {'0': GINI, '1': ENTROPY, '2': MSE, - '3': MAE, '4': CRITERION_END} + criterion_dict = {'0': GINI, 'gini': GINI, + '1': ENTROPY, 'entropy': ENTROPY, + '2': MSE, 'mse': MSE, + '3': MAE, 'mae': MAE, + '4': CRITERION_END} classes_ = CumlArrayDescriptor() @@ -104,14 +106,6 @@ class BaseRandomForestModel(Base): "recommended. 
If n_streams is > 1, results may vary " "due to stream/thread timing differences, even when " "random_state is set") - if 'use_experimental_backend' in kwargs.keys(): - warnings.warn("The 'use_experimental_backend' parameter is " - "deprecated and has no effect. " - "It will be removed in 21.10 release.") - if 'split_algo' in kwargs.keys(): - warnings.warn("The 'split_algo' parameter is " - "deprecated and has no effect. " - "It will be removed in 21.10 release.") if handle is None: handle = Handle(n_streams) @@ -247,8 +241,10 @@ class BaseRandomForestModel(Base): input_to_cuml_array(X, check_dtype=[np.float32, np.float64], order='F') if self.n_bins > self.n_rows: - raise ValueError("The number of bins,`n_bins` can not be greater" - " than the number of samples used for training.") + warnings.warn("The number of bins, `n_bins` is greater than " + "the number of samples used for training. " + "Changing `n_bins` to number of training samples.") + self.n_bins = self.n_rows if self.RF_type == CLASSIFICATION: y_m, _, _, y_dtype = \ @@ -329,14 +325,14 @@ class BaseRandomForestModel(Base): check_cols=self.n_cols) if dtype == np.float64 and not convert_dtype: - raise TypeError("GPU based predict only accepts np.float32 data. \ - Please set convert_dtype=True to convert the test \ - data to the same dtype as the data used to train, \ - ie. np.float32. If you would like to use test \ - data of dtype=np.float64 please set \ - predict_model='CPU' to use the CPU implementation \ - of predict.") - + warnings.warn("GPU based predict only accepts " + "np.float32 data. The model was " + "trained on np.float64 data hence " + "cannot use GPU-based prediction! " + "\nDefaulting to CPU-based Prediction. 
" + "\nTo predict on float-64 data, set " + "parameter predict_model = 'CPU'") + return self._predict_model_on_cpu(X, convert_dtype=convert_dtype) treelite_handle = self._obtain_treelite_handle() storage_type = \ @@ -365,6 +361,7 @@ class BaseRandomForestModel(Base): self.treelite_serialized_model = None super().set_params(**params) + return self def _check_fil_parameter_validity(depth, algo, fil_sparse_format): diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index f68bee6088..e6006ddf55 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -16,7 +16,6 @@ # # distutils: language = c++ - import numpy as np import rmm import warnings @@ -176,13 +175,11 @@ class RandomForestClassifier(BaseRandomForestModel, ----------- n_estimators : int (default = 100) Number of trees in the forest. (Default changed to 100 in cuML 0.11) - split_criterion : The criterion used to split nodes. - 0 for GINI, 1 for ENTROPY - 2 and 3 not valid for classification - (default = 0) - split_algo : int (default = 1) - Deprecated and currrently has no effect. - .. deprecated:: 21.06 + split_criterion : int or string (default = 0 ('gini')) + The criterion used to split nodes. + 0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY, + 2 or 'mse' for MSE + 2 or 'mse' not valid for classification bootstrap : boolean (default = True) Control bootstrapping. If True, each tree in the forest is built @@ -226,9 +223,6 @@ class RandomForestClassifier(BaseRandomForestModel, min_impurity_decrease : float (default = 0.0) Minimum decrease in impurity requried for node to be spilt. - use_experimental_backend : boolean (default = True) - Deprecated and currrently has no effect. - .. deprecated:: 21.08 max_batch_size: int (default = 4096) Maximum number of nodes that can be processed in a given batch. 
random_state : int (default = None) @@ -559,8 +553,7 @@ class RandomForestClassifier(BaseRandomForestModel, @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) def predict(self, X, predict_model="GPU", threshold=0.5, - algo='auto', num_classes=None, - convert_dtype=True, + algo='auto', convert_dtype=True, fil_sparse_format='auto') -> CumlArray: """ Predicts the labels for X. @@ -589,13 +582,6 @@ class RandomForestClassifier(BaseRandomForestModel, threshold : float (default = 0.5) Threshold used for classification. Optional and required only while performing the predict operation on the GPU. - num_classes : int (default = None) - number of different classes present in the dataset. - - .. deprecated:: 0.16 - Parameter 'num_classes' is deprecated and will be removed in - an upcoming version. The number of classes passed must match - the number of classes the model was trained on. convert_dtype : bool, optional (default = True) When set to True, the predict method will, when necessary, convert @@ -617,24 +603,19 @@ class RandomForestClassifier(BaseRandomForestModel, y : {} """ nvtx_range_push("predict RF-Classifier @randomforestclassifier.pyx") - if num_classes: - warnings.warn("num_classes is deprecated and will be removed" - " in an upcoming version") - if num_classes != self.num_classes: - raise NotImplementedError("limiting num_classes for predict" - " is not implemented") if predict_model == "CPU": preds = self._predict_model_on_cpu(X, convert_dtype=convert_dtype) - elif self.dtype == np.float64: - raise TypeError("GPU based predict only accepts np.float32 data. \ - In order use the GPU predict the model should \ - also be trained using a np.float32 dataset. \ - If you would like to use np.float64 dtype \ - then please use the CPU based predict by \ - setting predict_model = 'CPU'") - + warnings.warn("GPU based predict only accepts " + "np.float32 data. 
The model was " + "trained on np.float64 data hence " + "cannot use GPU-based prediction! " + "\nDefaulting to CPU-based Prediction. " + "\nTo predict on float-64 data, set " + "parameter predict_model = 'CPU'") + preds = self._predict_model_on_cpu(X, + convert_dtype=convert_dtype) else: preds = \ self._predict_model_on_gpu(X=X, output_class=True, @@ -650,7 +631,7 @@ class RandomForestClassifier(BaseRandomForestModel, @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) def predict_proba(self, X, algo='auto', - num_classes=None, convert_dtype=True, + convert_dtype=True, fil_sparse_format='auto') -> CumlArray: """ Predicts class probabilites for X. This function uses the GPU @@ -673,14 +654,6 @@ class RandomForestClassifier(BaseRandomForestModel, * ``'batch_tree_reorg'`` is used for dense storage and 'naive' for sparse storage - num_classes : int (default = None) - number of different classes present in the dataset. - - .. deprecated:: 0.16 - Parameter 'num_classes' is deprecated and will be removed in - an upcoming version. The number of classes passed must match - the number of classes the model was trained on. - convert_dtype : bool, optional (default = True) When set to True, the predict method will, when necessary, convert the input to the data type which was used to train the model. 
This @@ -708,15 +681,6 @@ class RandomForestClassifier(BaseRandomForestModel, then please use the CPU based predict by \ setting predict_model = 'CPU'") - if num_classes: - warnings.warn("num_classes is deprecated and will be removed" - " in an upcoming version") - if num_classes != self.num_classes: - raise NotImplementedError("The number of classes in the test " - "dataset should be equal to the " - "number of classes present in the " - "training dataset.") - preds_proba = \ self._predict_model_on_gpu(X, output_class=True, algo=algo, @@ -729,7 +693,7 @@ class RandomForestClassifier(BaseRandomForestModel, @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)'), ('dense_intdtype', '(n_samples, 1)')]) def score(self, X, y, threshold=0.5, - algo='auto', num_classes=None, predict_model="GPU", + algo='auto', predict_model="GPU", convert_dtype=True, fil_sparse_format='auto'): """ Calculates the accuracy metric score of the model for X. @@ -755,13 +719,6 @@ class RandomForestClassifier(BaseRandomForestModel, threshold is used to for classification This is optional and required only while performing the predict operation on the GPU. - num_classes : int (default = None) - number of different classes present in the dataset. - - .. deprecated:: 0.16 - Parameter 'num_classes' is deprecated and will be removed in - an upcoming version. The number of classes passed must match - the number of classes the model was trained on. 
convert_dtype : boolean, default=True whether to convert input data to correct dtype automatically @@ -803,7 +760,6 @@ class RandomForestClassifier(BaseRandomForestModel, threshold=threshold, algo=algo, convert_dtype=convert_dtype, predict_model=predict_model, - num_classes=num_classes, fil_sparse_format=fil_sparse_format) cdef uintptr_t preds_ptr diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index c96ff64eb6..11d1517eb8 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -161,22 +161,11 @@ class RandomForestRegressor(BaseRandomForestModel, ----------- n_estimators : int (default = 100) Number of trees in the forest. (Default changed to 100 in cuML 0.11) - split_algo : int (default = 1) - The algorithm to determine how nodes are split in the tree. - Can be changed only for the old backend [deprecated]. - 0 for HIST and 1 for GLOBAL_QUANTILE. Default is GLOBAL_QUANTILE. - The default backend does not support HIST. - HIST currently uses a slower tree-building algorithm so - GLOBAL_QUANTILE is recommended for most cases. - - .. deprecated:: 21.06 - Parameter 'split_algo' is deprecated and will be removed in - subsequent release. - split_criterion : int (default = 2) + split_criterion : int or string (default = 2 ('mse')) The criterion used to split nodes. - 0 for GINI, 1 for ENTROPY, - 2 for MSE - 0 and 1 not valid for regression + 0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY, + 2 or 'mse' for MSE + only 2 or 'mse' valid for regression bootstrap : boolean (default = True) Control bootstrapping. If True, each tree in the forest is built @@ -228,9 +217,6 @@ class RandomForestRegressor(BaseRandomForestModel, for median of abs error : 'median_ae' for mean of abs error : 'mean_ae' for mean square error' : 'mse' - use_experimental_backend : boolean (default = True) - Deprecated and currrently has no effect. - .. 
deprecated:: 21.08 max_batch_size: int (default = 4096) Maximum number of nodes that can be processed in a given batch. random_state : int (default = None) @@ -586,15 +572,16 @@ class RandomForestRegressor(BaseRandomForestModel, nvtx_range_push("predict RF-Regressor @randomforestregressor.pyx") if predict_model == "CPU": preds = self._predict_model_on_cpu(X, convert_dtype) - elif self.dtype == np.float64: - raise TypeError("GPU based predict only accepts np.float32 data. \ - In order use the GPU predict the model should \ - also be trained using a np.float32 dataset. \ - If you would like to use np.float64 dtype \ - then please use the CPU based predict by \ - setting predict_model = 'CPU'") - + warnings.warn("GPU based predict only accepts " + "np.float32 data. The model was " + "trained on np.float64 data hence " + "cannot use GPU-based prediction! " + "\nDefaulting to CPU-based Prediction. " + "\nTo predict on float-64 data, set " + "parameter predict_model = 'CPU'") + preds = self._predict_model_on_cpu(X, + convert_dtype=convert_dtype) else: preds = self._predict_model_on_gpu( X=X, diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index 028b2897c8..6d4ebf6813 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -188,7 +188,7 @@ def test_accuracy(nrows, ncols, n_info, datatype): # Initialize, fit and predict using cuML's # random forest classification model cuml_model = curfc(max_features=1.0, - n_bins=8, split_algo=0, split_criterion=0, + n_bins=8, split_criterion=0, min_samples_leaf=2, n_estimators=40, handle=handle, max_leaves=-1, max_depth=16) diff --git a/python/cuml/test/test_random_forest.py b/python/cuml/test/test_random_forest.py index 3d7c28774a..db038466e6 100644 --- a/python/cuml/test/test_random_forest.py +++ b/python/cuml/test/test_random_forest.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import warnings import cudf import numpy as np import pytest import random import json -import io import os -from contextlib import redirect_stdout from numba import cuda @@ -219,9 +218,7 @@ def test_rf_classification(small_clf, datatype, max_samples, max_features): max_leaves=-1, max_depth=16, ) - f = io.StringIO() - with redirect_stdout(f): - cuml_model.fit(X_train, y_train) + cuml_model.fit(X_train, y_train) fil_preds = cuml_model.predict( X_test, predict_model="GPU", threshold=0.5, algo="auto" @@ -400,11 +397,22 @@ def test_rf_classification_float64(small_clf, datatype, convert_dtype): fil_acc = accuracy_score(y_test, fil_preds) assert fil_acc >= (cu_acc - 0.07) # to be changed to 0.02. see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa - else: - with pytest.raises(TypeError): + # if GPU predict cannot be used, display warning and use CPU predict + elif datatype[1] == np.float64: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") fil_preds = cuml_model.predict( - X_test, predict_model="GPU", convert_dtype=convert_dtype + X_test, predict_model="GPU", + convert_dtype=convert_dtype ) + assert("GPU based predict only accepts " + "np.float32 data. The model was " + "trained on np.float64 data hence " + "cannot use GPU-based prediction! " + "\nDefaulting to CPU-based Prediction. 
" + "\nTo predict on float-64 data, set " + "parameter predict_model = 'CPU'" + in str(w[-1].message)) @pytest.mark.parametrize( @@ -447,10 +455,21 @@ def test_rf_regression_float64(large_reg, datatype): assert fil_r2 >= (cu_r2 - 0.02) # because datatype[0] != np.float32 or datatype[0] != datatype[1] - with pytest.raises(TypeError): - fil_preds = cuml_model.predict( - X_test, predict_model="GPU", convert_dtype=False - ) + # display warning when GPU-predict cannot be used and revert to CPU-predict + elif datatype[1] == np.float64: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + fil_preds = cuml_model.predict( + X_test, predict_model="GPU" + ) + assert("GPU based predict only accepts " + "np.float32 data. The model was " + "trained on np.float64 data hence " + "cannot use GPU-based prediction! " + "\nDefaulting to CPU-based Prediction. " + "\nTo predict on float-64 data, set " + "parameter predict_model = 'CPU'" + in str(w[-1].message)) def check_predict_proba(test_proba, baseline_proba, y_test, rel_err): @@ -1125,9 +1144,7 @@ def test_concat_memory_leak(large_clf, estimator_type): assert (used_mem - initial_baseline_mem) < 1e6 -@pytest.mark.xfail(strict=True, raises=ValueError) def test_rf_nbins_small(small_clf): - X, y = small_clf X = X.astype(np.float32) y = y.astype(np.int32) @@ -1137,7 +1154,15 @@ def test_rf_nbins_small(small_clf): # Initialize, fit and predict using cuML's # random forest classification model cuml_model = curfc() - cuml_model.fit(X_train[0:3, :], y_train[0:3]) + + # display warning when n_bins is greater than the number of training samples + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + cuml_model.fit(X_train[0:3, :], y_train[0:3]) + assert("The number of bins, `n_bins` is greater than " + "the number of samples used for training. " + "Changing `n_bins` to number of training samples." + in str(w[-1].message)) @pytest.mark.parametrize("split_criterion", [2], ids=["mse"])