cuml.experimental SHAP improvements #3433
Conversation
…18-fea-kshap-opt
Looks good! Only small comments
Codecov Report
@@            Coverage Diff             @@
##           branch-0.18    #3433      +/-   ##
===============================================
+ Coverage        71.48%   71.55%   +0.06%
===============================================
  Files              207      212       +5
  Lines            16748    17082     +334
===============================================
+ Hits             11973    12223     +250
- Misses            4775     4859      +84
Continue to review full report at Codecov.
rerun tests
3 similar comments
Looks good! I had only small suggestions, mostly stuff that can be deferred to future PRs since this is still experimental. My only concern is that the number of variations supported in datatypes (sklearn model with pandas background data or cuml with f-ordered numpy or ...) makes it hard to test all paths of the base shap initialization. Let's look at codecov there for additional test ideas and be open to simplifying the options if necessary.
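One way to picture the concern above: the input-variation matrix could be exercised with a parametrized test. The sketch below is illustrative only (it is not part of the existing test suite), covers just the background-data axis (C-ordered NumPy, F-ordered NumPy, pandas), and uses hypothetical helper and test names.

import numpy as np
import pandas as pd
import pytest

# Hypothetical helper: materialize the same background data in several of the
# layouts the explainer is expected to accept.
def _as_background(kind, data):
    if kind == "numpy_c":
        return np.ascontiguousarray(data)
    if kind == "numpy_f":
        return np.asfortranarray(data)
    if kind == "pandas":
        return pd.DataFrame(data)
    raise ValueError(kind)

@pytest.mark.parametrize("kind", ["numpy_c", "numpy_f", "pandas"])
def test_background_variations(kind):
    data = np.random.default_rng(0).standard_normal((10, 4)).astype(np.float32)
    background = _as_background(kind, data)
    # A real test would build the explainer with `background` and compare the
    # resulting shap values across layouts; here we only check the shape survives.
    assert np.asarray(background).shape == (10, 4)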
void shap_main_effect_dataset "ML::Explainer::shap_main_effect_dataset"(
    const handle_t& handle,
    float* dataset,
(in the underlying API) should dataset be const?
dataset is where the output is generated; maybe I should change the name to avoid confusion?
----------
model : function
    Function that takes a matrix of samples (n_samples, n_features) and
    computes the output for those samples with shape (n_samples). Function
Ah, bummer, so there is no way to use the tags API because we need to take the function rather than the model.
We already use the tags API by getting the owning object of the function (if it exists) and reading the tags from that:
def get_tag_from_model_func(func, tag, default=None):
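For context, a minimal sketch of the idea (not necessarily the actual cuml implementation): a bound method exposes its owner through __self__, and the owner's tags can then be queried, assuming it implements a sklearn-style _get_tags().

def get_tag_from_model_func(func, tag, default=None):
    # Bound methods carry their owning object in __self__; free functions do not.
    owner = getattr(func, "__self__", None)
    if owner is not None and hasattr(owner, "_get_tags"):
        # Assumes the owner exposes sklearn/cuml-style estimator tags.
        return owner._get_tags().get(tag, default)
    return default

# Example (illustrative tag name): get_tag_from_model_func(model.predict, "X_types")
# returns the owning estimator's tag if available, and the default otherwise.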
            default=np.float32)
        else:
            if dtype in [np.float32, np.float64]:
                self.dtype = np.dtype(dtype)
Out of curiosity, why do you have to convert to np.dtype?
I was doing things in the wrong order. I use the dtype function of NumPy so that we accept string descriptions of the dtypes without additional work.
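For reference, np.dtype() is what makes the string form work: it normalizes string names, scalar types and existing dtype objects to one canonical object.

import numpy as np

# 'float32', np.float32 and np.dtype('float32') all normalize to the same dtype,
# so accepting any of them costs a single np.dtype() call.
assert np.dtype("float32") == np.dtype(np.float32) == np.float32
assert np.dtype("float64").itemsize == 8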
Changes look good - just some doc and test suggestions. I think there is still a california_housing test coming? We could split that to the next PR too.
@@ -213,13 +198,17 @@ class SHAPBase():
            )
        )

        # public attribute saved as NumPy for compatibility with the legacy
        # SHAP plotting functions
        self.expected_value = cp.asnumpy(self._expected_value)
This makes sense, but it's a deviation from our standard approach... can you add a docstring to explain this? Can be a follow-up PR.
Also, it would be really good to have a test of compatibility with SHAP plotting so we never break this (again, a follow-on PR is ok).
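A small sketch of the compatibility point (assuming cupy is available): the legacy shap plotting helpers expect host-side NumPy data, so the device-resident value is copied back once and exposed as a plain ndarray.

import cupy as cp
import numpy as np

# Device-side value computed by the explainer (illustrative value).
_expected_value = cp.asarray([0.5], dtype=cp.float32)

# Copy to host so downstream NumPy/matplotlib-based plotting code can consume it.
expected_value = cp.asnumpy(_expected_value)
assert isinstance(expected_value, np.ndarray)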
…d other small fixes
rerun tests
Pre-approving with some small suggestions/questions. Looks great!
// gemv, which could cause a very sporadic race condition in Pascal and
// Turing GPUs that caused it to give the wrong results. Details:
// https://github.com/rapidsai/cuml/issues/1739
rmm::device_uvector<math_t> tmp_vector(n_cols, stream);
How about something like tmp_gemv_result, or another name indicating its use?
def output_list_shap_values(X, dimensions, output_type):
    if output_type == 'cupy':
        if dimensions == 1:
            return X[0]
        else:
            return X
    res = []
Super picky, but this seems like either a list comprehension or just list(X) would be nicer.
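For illustration, the loop-versus-comprehension equivalence the comment is pointing at, with a hypothetical convert standing in for whatever per-slice conversion the loop body performs:

def build_list_loop(X, convert):
    res = []
    for x in X:
        res.append(convert(x))
    return res

def build_list_comprehension(X, convert):
    return [convert(x) for x in X]

# Both produce the same list; when no per-element conversion is needed,
# list(X) is the simplest spelling of all.
assert build_list_loop(range(3), str) == build_list_comprehension(range(3), str) == ["0", "1", "2"]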
@@ -399,6 +416,11 @@ def test_l1_regularization(exact_tests_dataset, l1_type):
     0.00088981]
]

housing_regression_result = np.array(
Was this obtained by running shap? Would be good to note in a comment what you did to get it and what version you used.
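For the record, a hypothetical sketch of how such a reference array could be produced with the CPU shap package; the actual model, background sample and shap version used for the values above would still need to be recorded in the test comment.

import numpy as np
import shap
from sklearn.datasets import fetch_california_housing
from sklearn.linear_model import LinearRegression

X, y = fetch_california_housing(return_X_y=True)
model = LinearRegression().fit(X, y)

# KernelExplainer with a small background sample; the explained row and the
# background size here are placeholders, not necessarily the ones used in the test.
explainer = shap.KernelExplainer(model.predict, X[:100])
reference_values = explainer.shap_values(X[:1])

print("shap version:", shap.__version__)
print(np.asarray(reference_values))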
rerun tests
@gpucibot merge
Closes #1739
Addresses most items of #3224