-
Notifications
You must be signed in to change notification settings - Fork 550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[REVIEW] Add weighted K-Means sampling for SHAP #4051
[REVIEW] Add weighted K-Means sampling for SHAP #4051
Conversation
Codecov Report
@@ Coverage Diff @@
## branch-21.08 #4051 +/- ##
===============================================
Coverage ? 85.80%
===============================================
Files ? 232
Lines ? 18314
Branches ? 0
===============================================
Hits ? 15714
Misses ? 2600
Partials ? 0
Flags with carried forward coverage won't be shown. Click here to find out more. Continue to review full report at Codecov.
|
if output_dtype == cudf.DataFrame: | ||
group_names = X.columns | ||
X = X.values | ||
elif output_dtype == cudf.Series: | ||
group_names = X.name | ||
X = X.values.reshape(-1, 1) | ||
elif output_dtype == pd.DataFrame: | ||
group_names = X.columns | ||
X = cp.array(X.values) | ||
elif output_dtype == pd.Series: | ||
group_names = X.name | ||
X = cp.array(X.values.reshape(-1, 1)) | ||
else: | ||
# it's either numpy, cupy or numba | ||
if output_dtype == cuda.devicearray.DeviceNDArrayBase: | ||
X = cp.array(X) | ||
elif output_dtype == np.ndarray: | ||
X = cp.array(X) | ||
try: | ||
# more than one column | ||
group_names = [str(i) for i in range(X.shape[1])] | ||
except IndexError: | ||
# one column | ||
X = X.reshape(-1, 1) | ||
group_names = ['0'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code probably can be simplified further, but we can do that as a follow up PR for 21.10
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened an issue: #4121
@gpucibot merge |
Adding sampling method for SHAP using k-means, adapted from https://github.com/slundberg/shap/blob/9411b68e8057a6c6f3621765b89b24d82bee13d4/shap/utils/_legacy.py Moving the code from interpret-community package for easier maintenance. Chose not to add comparison with SHAP as it will add a dependency to SHAP not sure if we want that. Closes rapidsai#4000 Authors: - Nanthini (https://github.com/Nanthini10) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: rapidsai#4051
Adding sampling method for SHAP using k-means, adapted from https://github.com/slundberg/shap/blob/9411b68e8057a6c6f3621765b89b24d82bee13d4/shap/utils/_legacy.py
Moving the code from interpret-community package for easier maintenance.
Chose not to add comparison with SHAP as it will add a dependency to SHAP not sure if we want that.
Closes #4000