Skip to content

Commit

Permalink
Fix RandomForestClassifier return type (rapidsai#5896)
Browse files Browse the repository at this point in the history
Closes rapidsai#5637 

```
import cuml
from cuml.datasets import make_classification

X, y = make_classification()

clf = cuml.ensemble.RandomForestClassifier().fit(X,y)
print(clf.predict(X[:5]).dtype)
```

Result is

```
int64
```

Authors:
  - Jinsol Park (https://github.com/jinsolp)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#5896
  • Loading branch information
jinsolp authored May 23, 2024
1 parent 2a5185d commit 47416d7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -550,6 +550,7 @@ class RandomForestClassifier(BaseRandomForestModel,
domain="cuml_python")
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
@cuml.internals.api_base_return_array(get_output_dtype=True)
def predict(self, X, predict_model="GPU", threshold=0.5,
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand Down
8 changes: 8 additions & 0 deletions python/cuml/tests/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,3 +1382,11 @@ def test_rf_min_samples_split_with_small_float(estimator, make_data):

# Does not error
clf.fit(X, y)


def test_rf_predict_returns_int():

X, y = make_classification()
clf = cuml.ensemble.RandomForestClassifier().fit(X, y)
pred = clf.predict(X)
assert pred.dtype == np.int64

0 comments on commit 47416d7

Please sign in to comment.