-
Notifications
You must be signed in to change notification settings - Fork 546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support predict in MNMG Logistic Regression #5516
Conversation
Pull requests from external contributors require approval from a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes, @lijinf2. I think it looks great overall and it's nice to be able both train and predict! Just one concern/question, really.
@@ -31,7 +32,8 @@ | |||
np = cpu_only_import("numpy") | |||
|
|||
|
|||
class LogisticRegression(BaseEstimator, SyncFitMixinLinearModel): | |||
# class LogisticRegression(BaseEstimator, SyncFitMixinLinearModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove these two mixins?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mixins have been included in the cuml.dask.linear_model.LinearRegression. See the definition: Class LinearRegression(
BaseEstimator, SyncFitMixinLinearModel, DelayedPredictionMixin)
The next line LogisticRegression(LinearRegression) automatically inherits the two mixins, and reuses the code/functions of the LinearRegression class.
/merge |
The init arguments are for LBFGS (the only algorithm in the current MNMG LogisticRegression). The key code changes should be a few lines after [PR 5516 for predict](#5516) gets merged. Key code changes can be reviewed from [here](https://github.com/rapidsai/cuml/pull/5519/files/d058d884c992661984224d0190c3bbcc0a23caf4..fbbaa5c6aef47ddc7100f5bea2a751851ca6d1b4) Authors: - Jinfeng Li (https://github.com/lijinf2) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #5519
This is a followup PR for PR 5477. This PR adds predict API to MNMG logistic regression and tests to verify the correctness.
Please review the code change from commit 171aef2 with message "add predict operator". The implementation is trivial after the dependency PR 5477 is merged.