diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index 1b89def8de..77d457e98e 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -36,6 +36,7 @@ from cuml.ensemble.randomforest_shared import treelite_serialize, \ from cuml.ensemble.randomforest_shared cimport * from cuml.common import input_to_cuml_array from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.prims.label.classlabels import make_monotonic, check_labels class BaseRandomForestModel(Base): @@ -283,11 +284,10 @@ class BaseRandomForestModel(Base): " `int32`") self.classes_ = cp.unique(y_m) self.num_classes = len(self.classes_) - for i in range(self.num_classes): - if i not in self.classes_: - raise ValueError("The labels need " - "to be consecutive values from " - "0 to the number of unique label values") + self.use_monotonic = not check_labels( + y_m, cp.arange(self.num_classes, dtype=np.int32)) + if self.use_monotonic: + y_m, _ = make_monotonic(y_m) else: y_m, _, _, y_dtype = \ diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 99d9cf69c0..7f1a79689d 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -46,6 +46,7 @@ from libc.stdint cimport uintptr_t, uint64_t from libc.stdlib cimport calloc, malloc, free from numba import cuda +from cuml.prims.label.classlabels import check_labels, invert_labels from pylibraft.common.handle cimport handle_t cimport cuml.common.cuda @@ -431,6 +432,8 @@ class RandomForestClassifier(BaseRandomForestModel, X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y, convert_dtype) + # Track the labels to see if update is necessary + self.update_labels = not check_labels(y_m, self.classes_) cdef uintptr_t X_ptr, y_ptr X_ptr = X_m.ptr @@ -611,6 +614,9 @@ class RandomForestClassifier(BaseRandomForestModel, fil_sparse_format=fil_sparse_format, predict_proba=False) + if self.update_labels: + preds = preds.to_output().astype(self.classes_.dtype) + preds = invert_labels(preds, self.classes_) return preds @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], diff --git a/python/cuml/tests/test_random_forest.py b/python/cuml/tests/test_random_forest.py index e0997a1630..0dd3e501d9 100644 --- a/python/cuml/tests/test_random_forest.py +++ b/python/cuml/tests/test_random_forest.py @@ -307,6 +307,64 @@ def test_rf_classification(small_clf, datatype, max_samples, max_features): assert fil_acc >= (cuml_acc - 0.07) # to be changed to 0.02. see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa +@pytest.mark.parametrize( + "max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)] +) +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) +def test_rf_classification_unorder( + small_clf, datatype, max_samples, max_features=1, a=2, b=5): + use_handle = True + + X, y = small_clf + X = X.astype(datatype) + y = y.astype(np.int32) + # affine transformation + y = a*y+b + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) + # Create a handle for the cuml model + handle, stream = get_handle(use_handle, n_streams=1) + + # Initialize, fit and predict using cuML's + # random forest classification model + cuml_model = curfc( + max_features=max_features, + max_samples=max_samples, + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=123, + n_streams=1, + n_estimators=40, + handle=handle, + max_leaves=-1, + max_depth=16, + ) + cuml_model.fit(X_train, y_train) + + fil_preds = cuml_model.predict( + X_test, predict_model="GPU", threshold=0.5, algo="auto" + ) + cu_preds = cuml_model.predict(X_test, predict_model="CPU") + fil_preds = np.reshape(fil_preds, np.shape(cu_preds)) + cuml_acc = accuracy_score(y_test, cu_preds) + fil_acc = accuracy_score(y_test, fil_preds) + if X.shape[0] < 500000: + sk_model = skrfc( + n_estimators=40, + max_depth=16, + min_samples_split=2, + max_features=max_features, + random_state=10, + ) + sk_model.fit(X_train, y_train) + sk_preds = sk_model.predict(X_test) + sk_acc = accuracy_score(y_test, sk_preds) + assert fil_acc >= (sk_acc - 0.07) + assert fil_acc >= (cuml_acc - 0.07) # to be changed to 0.02. see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa + + @pytest.mark.parametrize( "max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)] )