-
Notifications
You must be signed in to change notification settings - Fork 540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEA] Add predict_proba() to XGBoost-style models in FIL C++ #2894
Conversation
Please update the changelog in order to start CI tests. View the gpuCI docs here. |
To copy the Slack response: I definitely agree that there can be a neat single template for softmax, and I will implement the suggestions above |
using BlockReduceHost = | ||
typename cub::BlockReduce<vec<NITEMS, float>, FIL_TPB, | ||
cub::BLOCK_REDUCE_WARP_REDUCTIONS, 1, 1, 600>; | ||
size_t block_reduce_footprint_host() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was changing half the cases where this was used, as well as deleting the device-side template using
statements. This meant the host ones were left alone and did not need to define class separately from footprint. This allowed to implement smem footprint additions for GROVE_PER_CLASS_* in a much more readable way (see below).
now conflicts with #3088 for combo ML::fil::output_t combo values (but a simple update to resolve merge conflict) |
This reverts commit 8581222.
depends on #3582 |
# FIL doesn't yet support predict_proba() for multi-class | ||
# TODO: Add a test for predict_proba() when it's supported | ||
gbm_preds = bst.predict(X) | ||
gbm_preds = gbm_preds.argmax(axis=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since lightgbm doesn't support probabilities without the sklearn API, using it here for both predictions
rerun tests |
@JohnZed PR looks good to me, ready to merge and has addressed all feedback, when you have a sec could you take a look? |
@gpucibot merge |
Codecov Report
@@ Coverage Diff @@
## branch-0.19 #2894 +/- ##
===============================================
+ Coverage 80.70% 80.74% +0.03%
===============================================
Files 227 227
Lines 17615 17737 +122
===============================================
+ Hits 14217 14322 +105
- Misses 3398 3415 +17
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
No description provided.