Skip to content

Commit

Permalink
FIX Accept small floats in RandomForest (rapidsai#4717)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored May 11, 2022
1 parent f7ccdca commit 1f246d6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class BaseRandomForestModel(Base):
math.ceil(self.min_samples_leaf * self.n_rows)
if type(self.min_samples_split) == float:
self.min_samples_split = \
math.ceil(self.min_samples_split * self.n_rows)
max(2, math.ceil(self.min_samples_split * self.n_rows))
return X_m, y_m, max_feature_val

def _tl_handle_from_bytes(self, treelite_serialized_model):
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ class RandomForestClassifier(BaseRandomForestModel,
* If type ``int``, then min_samples_split represents the minimum
number.
* If type ``float``, then ``min_samples_split`` represents a fraction
and ``ceil(min_samples_split * n_rows)`` is the minimum number of
samples for each split.
and ``max(2, ceil(min_samples_split * n_rows))`` is the minimum
number of samples for each split.
min_impurity_decrease : float (default = 0.0)
Minimum decrease in impurity required for
node to be split.
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class RandomForestRegressor(BaseRandomForestModel,
* If type ``int``, then min_samples_split represents the minimum
number.
* If type ``float``, then ``min_samples_split`` represents a fraction
and ``ceil(min_samples_split * n_rows)`` is the minimum number of
samples for each split.
and ``max(2, ceil(min_samples_split * n_rows))`` is the minimum
number of samples for each split.
min_impurity_decrease : float (default = 0.0)
The minimum decrease in impurity required for node to be split
accuracy_metric : string (default = 'r2')
Expand Down
14 changes: 14 additions & 0 deletions python/cuml/tests/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,3 +1302,17 @@ def test_rf_multiclass_classifier_gtil_integration(tmpdir):
tl_model = treelite.Model.deserialize(checkpoint_path)
out_prob = treelite.gtil.predict(tl_model, X, pred_margin=True)
np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5)


@pytest.mark.parametrize("estimator, make_data", [
    (curfc, make_classification),
    (curfr, make_regression),
])
def test_rf_min_samples_split_with_small_float(estimator, make_data):
    """Check that fitting works when ``min_samples_split`` is a small float.

    Non-regression test for gh-4613.
    """
    features, targets = make_data(random_state=0)
    forest = estimator(
        n_estimators=2,
        random_state=0,
        min_samples_split=0.0001,
    )

    # The test passes as long as fit() does not raise.
    forest.fit(features, targets)

0 comments on commit 1f246d6

Please sign in to comment.