Skip to content

Commit

Permalink
Transforms RandomForest estimators non-consecutive labels to consecut…
Browse files Browse the repository at this point in the history
…ive labels where appropriate (rapidsai#4780)

This PR closes rapidsai#4478 by transforming non-consecutive labels outside of [0,n) to consecutive labels inside [0,n) similar to what Scikit-learn does under the hood.

Closes rapidsai#691

Authors:
  - https://github.com/VamsiTallam95

Approvers:
  - Micka (https://github.com/lowener)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4780
  • Loading branch information
VamsiTallam95 authored Sep 29, 2022
1 parent 359aa6e commit 0c5988d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ from cuml.ensemble.randomforest_shared import treelite_serialize, \
from cuml.ensemble.randomforest_shared cimport *
from cuml.common import input_to_cuml_array
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.prims.label.classlabels import make_monotonic, check_labels


class BaseRandomForestModel(Base):
Expand Down Expand Up @@ -283,11 +284,10 @@ class BaseRandomForestModel(Base):
" `int32`")
self.classes_ = cp.unique(y_m)
self.num_classes = len(self.classes_)
for i in range(self.num_classes):
if i not in self.classes_:
raise ValueError("The labels need "
"to be consecutive values from "
"0 to the number of unique label values")
self.use_monotonic = not check_labels(
y_m, cp.arange(self.num_classes, dtype=np.int32))
if self.use_monotonic:
y_m, _ = make_monotonic(y_m)

else:
y_m, _, _, y_dtype = \
Expand Down
6 changes: 6 additions & 0 deletions python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ from libc.stdint cimport uintptr_t, uint64_t
from libc.stdlib cimport calloc, malloc, free

from numba import cuda
from cuml.prims.label.classlabels import check_labels, invert_labels

from pylibraft.common.handle cimport handle_t
cimport cuml.common.cuda
Expand Down Expand Up @@ -431,6 +432,8 @@ class RandomForestClassifier(BaseRandomForestModel,

X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y,
convert_dtype)
# Track the labels to see if update is necessary
self.update_labels = not check_labels(y_m, self.classes_)
cdef uintptr_t X_ptr, y_ptr

X_ptr = X_m.ptr
Expand Down Expand Up @@ -611,6 +614,9 @@ class RandomForestClassifier(BaseRandomForestModel,
fil_sparse_format=fil_sparse_format,
predict_proba=False)

if self.update_labels:
preds = preds.to_output().astype(self.classes_.dtype)
preds = invert_labels(preds, self.classes_)
return preds

@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
Expand Down
58 changes: 58 additions & 0 deletions python/cuml/tests/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,64 @@ def test_rf_classification(small_clf, datatype, max_samples, max_features):
assert fil_acc >= (cuml_acc - 0.07) # to be changed to 0.02. see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa


@pytest.mark.parametrize(
"max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)]
)
@pytest.mark.parametrize("datatype", [np.float32, np.float64])
def test_rf_classification_unorder(
small_clf, datatype, max_samples, max_features=1, a=2, b=5):
use_handle = True

X, y = small_clf
X = X.astype(datatype)
y = y.astype(np.int32)
# affine transformation
y = a*y+b
X_train, X_test, y_train, y_test = train_test_split(
X, y, train_size=0.8, random_state=0
)
# Create a handle for the cuml model
handle, stream = get_handle(use_handle, n_streams=1)

# Initialize, fit and predict using cuML's
# random forest classification model
cuml_model = curfc(
max_features=max_features,
max_samples=max_samples,
n_bins=16,
split_criterion=0,
min_samples_leaf=2,
random_state=123,
n_streams=1,
n_estimators=40,
handle=handle,
max_leaves=-1,
max_depth=16,
)
cuml_model.fit(X_train, y_train)

fil_preds = cuml_model.predict(
X_test, predict_model="GPU", threshold=0.5, algo="auto"
)
cu_preds = cuml_model.predict(X_test, predict_model="CPU")
fil_preds = np.reshape(fil_preds, np.shape(cu_preds))
cuml_acc = accuracy_score(y_test, cu_preds)
fil_acc = accuracy_score(y_test, fil_preds)
if X.shape[0] < 500000:
sk_model = skrfc(
n_estimators=40,
max_depth=16,
min_samples_split=2,
max_features=max_features,
random_state=10,
)
sk_model.fit(X_train, y_train)
sk_preds = sk_model.predict(X_test)
sk_acc = accuracy_score(y_test, sk_preds)
assert fil_acc >= (sk_acc - 0.07)
assert fil_acc >= (cuml_acc - 0.07) # to be changed to 0.02. see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa


@pytest.mark.parametrize(
"max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)]
)
Expand Down

0 comments on commit 0c5988d

Please sign in to comment.