-
Notifications
You must be signed in to change notification settings - Fork 552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update predict() / predict_proba() of RF to match sklearn #3609
Update predict() / predict_proba() of RF to match sklearn #3609
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just one small question.
@@ -271,7 +271,7 @@ def fit(self, X, y, convert_dtype=False): | |||
convert_dtype=convert_dtype) | |||
return self | |||
|
|||
def predict(self, X, output_class=True, algo='auto', threshold=0.5, | |||
def predict(self, X, algo='auto', threshold=0.5, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we remove this threshold too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(it's not in sklearn)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is fine, since predict()
produces class predictions.
@gpucibot merge |
Closes #3682 In #3609, I unintentionally broke the function `score()` in the random forest. This PR restores the functionality. In addition, I added `score()` to one of the unit tests to ensure that `score()` does not break again. Authors: - Philip Hyunsu Cho (https://github.com/hcho3) Approvers: - John Zedlewski (https://github.com/JohnZed) URL: #3685
Closes #3347.
Make the
predict()
andpredict_proba()
functions of RF to match those in the scikit-learn RF.output_class
. Instead,predict()
will always produce class prediction, andpredict_proba()
will always produce probability prediction. (This applies to binary and multi-class classifiers. Regressors will only havepredict()
.)threshold
parameter frompredict_proba()
.