Skip to content

Commit

Permalink
Kernel shap improvements (#5187)
Browse files Browse the repository at this point in the history
Removed the slow modulo operator via a minor change in index arithmetic. This gave me the following performance improvement for a test case:

|                         | branch-23.02     |kernel-shap-improvments  | Gain |
|-------------------------|------------------|-------------------------|------|
| sampled_rows_kernel     | 663              | 193                     | 3.4x |
| exact_rows_kernel       | 363              | 236                     | 1.5x |

All times in microseconds.

Code used for benchmarking:
```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestRegressor as rf
from sklearn.model_selection import train_test_split

from cuml.explainer import KernelExplainer

# Synthetic dataset: 1000 rows, 20 features, all of them informative.
data, labels = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=20,
    n_redundant=0,
    n_repeated=0,
    random_state=42,
)

# Hold out two rows to explain; the remaining 998 act as the background set.
X_train, X_test, y_train, y_test = train_test_split(
    data, labels, train_size=998, random_state=42
)
y_train = np.ravel(y_train)
y_test = np.ravel(y_test)

# Fit a random-forest regressor on the training split.
model = rf(random_state=42).fit(X_train, y_train)

# Explain the held-out rows with cuML's KernelExplainer
# (model is a CPU predictor, hence is_gpu_model=False).
cu_explainer = KernelExplainer(
    model=model.predict,
    data=X_train,
    is_gpu_model=False,
    random_state=42,
    nsamples=100,
)
cu_shap_values = cu_explainer.shap_values(X_test)
print('cu_shap:', cu_shap_values)

```

Authors:
  - Vinay Deshpande (https://github.com/vinaydes)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5187
  • Loading branch information
vinaydes authored Feb 8, 2023
1 parent 8d2e291 commit 5773725
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions cpp/src/explainer/kernel_shap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ __global__ void exact_rows_kernel(float* X,
int curr_X = (int)X[row + col];

// Iterate over nrows_background
for (int row_idx = blockIdx.x * nrows_background;
row_idx < blockIdx.x * nrows_background + nrows_background;
row_idx += 1) {
int row_idx_base = blockIdx.x * nrows_background;

for (int r = 0; r < nrows_background; r++) {
int row_idx = row_idx_base + r;
if (curr_X == 0) {
dataset[row_idx * ncols + col] = background[(row_idx % nrows_background) * ncols + col];
dataset[row_idx * ncols + col] = background[r * ncols + col];
} else {
dataset[row_idx * ncols + col] = observation[col];
}
Expand Down Expand Up @@ -139,26 +140,25 @@ __global__ void sampled_rows_kernel(IdxT* nsamples,
int curr_X = (int)X[2 * blockIdx.x * ncols + col_idx];
X[(2 * blockIdx.x + 1) * ncols + col_idx] = 1 - curr_X;

for (int bg_row_idx = 2 * blockIdx.x * nrows_background;
bg_row_idx < 2 * blockIdx.x * nrows_background + nrows_background;
bg_row_idx += 1) {
int bg_row_idx_base = 2 * blockIdx.x * nrows_background;

for (int r = 0; r < nrows_background; r++) {
int bg_row_idx = bg_row_idx_base + r;
if (curr_X == 0) {
dataset[bg_row_idx * ncols + col_idx] =
background[(bg_row_idx % nrows_background) * ncols + col_idx];
dataset[bg_row_idx * ncols + col_idx] = background[r * ncols + col_idx];
} else {
dataset[bg_row_idx * ncols + col_idx] = observation[col_idx];
}
}

for (int bg_row_idx = (2 * blockIdx.x + 1) * nrows_background;
bg_row_idx < (2 * blockIdx.x + 1) * nrows_background + nrows_background;
bg_row_idx += 1) {
bg_row_idx_base = 2 * (blockIdx.x + 1) * nrows_background;

for (int r = 0; r < nrows_background; r++) {
int bg_row_idx = bg_row_idx_base + r;
if (curr_X == 0) {
dataset[bg_row_idx * ncols + col_idx] = observation[col_idx];
} else {
// if(threadIdx.x == 0) printf("tid bg_row_idx: %d %d\n", tid, bg_row_idx);
dataset[bg_row_idx * ncols + col_idx] =
background[(bg_row_idx) % nrows_background * ncols + col_idx];
dataset[bg_row_idx * ncols + col_idx] = background[r * ncols + col_idx];
}
}

Expand Down

0 comments on commit 5773725

Please sign in to comment.