cuml.experimental SHAP improvements (#3433)
Closes #1739 

Addresses most items of #3224

Authors:
  - Dante Gama Dessavre (@dantegd)

Approvers:
  - John Zedlewski (@JohnZed)

URL: #3433
dantegd authored Feb 10, 2021
1 parent ec462e7 commit 8082f3b
Showing 16 changed files with 659 additions and 409 deletions.
2 changes: 1 addition & 1 deletion cpp/include/cuml/explainer/kernel_shap.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion cpp/include/cuml/explainer/permutation_shap.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion cpp/src/explainer/kernel_shap.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
25 changes: 23 additions & 2 deletions cpp/src/explainer/permutation_shap.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,17 +27,30 @@ __global__ void _fused_tile_scatter_pe(DataT* dataset, const DataT* background,
bool row_major) {
// kernel that actually does the scattering as described in the
// descriptions of `permutation_dataset` and `shap_main_effect_dataset`
// parameter sc_size allows us to generate both the permutation_shap_dataset
// and the main_effect_dataset with the same kernel, since they do the
// scattering in the same manner; it's just the "height" of the columns
// generated from the observation values that is different.
IdxT tid = threadIdx.x + blockDim.x * blockIdx.x;

if (tid < ncols * nrows_dataset) {
IdxT row, col, start, end;

if (row_major) {
row = tid / ncols;

// we calculate the first row where the entry of the dataset will be
// taken from the observation instead of the background, depending on
// its place in the index array
col = idx[tid % ncols];
start = ((tid % ncols) + 1) * nrows_background;

// each entry of the observation will be written the same number of
// times into the matrix, controlled by the sc_size parameter
end = start + sc_size * nrows_background;

// now we just need to check whether this row falls between start and
// end: if it does, the value comes from the observation obs, otherwise
// from the background dataset
if ((start <= row && row < end)) {
dataset[row * ncols + col] = obs[col];
} else {
@@ -49,7 +62,11 @@ __global__ void _fused_tile_scatter_pe(DataT* dataset, const DataT* background,
col = tid / nrows_dataset;
row = tid % nrows_dataset;

// the main difference between row major and col major is how we
// calculate start and end, and which row corresponds to each thread
start = nrows_background + idx[col] * nrows_background;

// calculation of end position is identical
end = start + sc_size * nrows_background;

if ((start <= row && row < end)) {
@@ -78,6 +95,8 @@ void permutation_shap_dataset_impl(const raft::handle_t& handle, DataT* dataset,

IdxT nblks = (nrows_dataset * ncols + nthreads - 1) / nthreads;

// each thread calculates a single element
// for the permutation shap dataset we need the sc_size parameter to be ncols
_fused_tile_scatter_pe<<<nblks, nthreads, 0, stream>>>(
dataset, background, nrows_dataset, ncols, row, idx, nrows_background,
ncols, row_major);
@@ -107,13 +126,15 @@ void shap_main_effect_dataset_impl(const raft::handle_t& handle, DataT* dataset,
const auto& handle_impl = handle;
cudaStream_t stream = handle_impl.get_stream();

// we calculate the number of rows in the dataset
// we calculate the number of elements in the dataset
IdxT total_num_elements = (nrows_bg * ncols + nrows_bg) * ncols;

constexpr IdxT nthreads = 512;

IdxT nblks = (total_num_elements + nthreads - 1) / nthreads;

// each thread calculates a single element
// for the permutation shap dataset we need the sc_size parameter to be 1
_fused_tile_scatter_pe<<<nblks, nthreads, 0, stream>>>(
dataset, background, total_num_elements / ncols, ncols, row, idx, nrows_bg,
1, row_major);
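To make the scattering above concrete, the following host-side C++ sketch reproduces the kernel's index math for the row-major case. It is an illustration only, not part of the commit: `fused_tile_scatter_cpu` is a hypothetical name, and the `else` branch (truncated in the diff) is assumed to copy the matching entry of the vertically tiled background, as the comments describe. With `sc_size = 1` it produces the main-effect dataset; with `sc_size = ncols` it produces the permutation dataset.

```cpp
// Illustrative CPU sketch of the index math in _fused_tile_scatter_pe
// (row-major case). Assumes the truncated `else` branch copies the matching
// entry of the tiled background, which is what the surrounding comments say.
#include <cstdio>
#include <vector>

void fused_tile_scatter_cpu(std::vector<float>& dataset,          // nrows_dataset x ncols
                            const std::vector<float>& background, // nrows_background x ncols
                            int nrows_dataset, int ncols,
                            const std::vector<float>& obs,        // observation, length ncols
                            const std::vector<int>& idx,          // permutation of the columns
                            int nrows_background, int sc_size) {
  for (int tid = 0; tid < nrows_dataset * ncols; ++tid) {
    int row   = tid / ncols;
    int col   = idx[tid % ncols];
    int start = ((tid % ncols) + 1) * nrows_background;  // first row scattered for this column
    int end   = start + sc_size * nrows_background;      // column stays scattered for sc_size tiles
    if (start <= row && row < end) {
      dataset[row * ncols + col] = obs[col];              // value taken from the observation
    } else {
      dataset[row * ncols + col] =
        background[(row % nrows_background) * ncols + col];  // value from the tiled background
    }
  }
}

int main() {
  const int ncols = 3, nrows_background = 2;
  std::vector<float> background = {0, 0, 0,
                                   1, 1, 1};     // 2 x 3 background rows
  std::vector<float> obs = {10, 20, 30};         // observation being explained
  std::vector<int> idx = {0, 1, 2};              // identity permutation for clarity

  // Main-effect dataset: sc_size = 1, (ncols + 1) * nrows_background rows.
  int nrows_dataset = (ncols + 1) * nrows_background;
  std::vector<float> dataset(nrows_dataset * ncols);
  fused_tile_scatter_cpu(dataset, background, nrows_dataset, ncols, obs, idx,
                         nrows_background, /*sc_size=*/1);

  for (int r = 0; r < nrows_dataset; ++r) {
    for (int c = 0; c < ncols; ++c) printf("%5.1f ", dataset[r * ncols + c]);
    printf("\n");
  }
  return 0;
}
```

For this toy 2x3 background, the printed dataset tiles the background (ncols + 1) times and replaces one column per tile with the observation value: rows 2-3 in column 0, rows 4-5 in column 1, and rows 6-7 in column 2, which is exactly the main-effect layout.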
6 changes: 2 additions & 4 deletions cpp/src/glm/ols.cuh
@@ -75,10 +75,8 @@ void olsFit(const raft::handle_t &handle, math_t *input, int n_rows, int n_cols,
fit_intercept, normalize, stream);
}

if (algo == 0 || n_cols == 1) {
LinAlg::lstsqSVD(handle, input, n_rows, n_cols, labels, coef, stream);
} else if (algo == 1) {
LinAlg::lstsqEig(handle, input, n_rows, n_cols, labels, coef, stream);
if (algo == 0 || algo == 1) {
LinAlg::lstsq(handle, input, n_rows, n_cols, labels, coef, algo, stream);
} else if (algo == 2) {
LinAlg::lstsqQR(input, n_rows, n_cols, labels, coef, cusolver_handle,
cublas_handle, allocator, stream);
66 changes: 25 additions & 41 deletions cpp/src_prims/linalg/lstsq.cuh
@@ -16,6 +16,7 @@

#pragma once

#include <raft/cudart_utils.h>
#include <raft/linalg/cublas_wrappers.h>
#include <raft/linalg/cusolver_wrappers.h>
#include <raft/linalg/gemv.h>
@@ -31,65 +32,48 @@
#include <raft/matrix/matrix.cuh>
#include <raft/mr/device/buffer.hpp>
#include <raft/random/rng.cuh>
#include <rmm/device_uvector.hpp>

namespace MLCommon {
namespace LinAlg {

template <typename math_t>
void lstsqSVD(const raft::handle_t &handle, math_t *A, int n_rows, int n_cols,
math_t *b, math_t *w, cudaStream_t stream) {
auto allocator = handle.get_device_allocator();
void lstsq(const raft::handle_t &handle, math_t *A, int n_rows, int n_cols,
math_t *b, math_t *w, int algo, cudaStream_t stream) {
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cublasHandle_t cublasH = handle.get_cublas_handle();

ASSERT(n_cols > 0, "lstsq: number of columns cannot be less than one");
ASSERT(n_rows > 1, "lstsq: number of rows cannot be less than two");

int U_len = n_rows * n_cols;
int V_len = n_cols * n_cols;

raft::mr::device::buffer<math_t> S(allocator, stream, n_cols);
raft::mr::device::buffer<math_t> V(allocator, stream, V_len);
raft::mr::device::buffer<math_t> U(allocator, stream, U_len);
raft::mr::device::buffer<math_t> UT_b(allocator, stream, n_rows);
rmm::device_uvector<math_t> S(n_cols, stream);
rmm::device_uvector<math_t> V(V_len, stream);
rmm::device_uvector<math_t> U(U_len, stream);

raft::linalg::svdQR(handle, A, n_rows, n_cols, S.data(), U.data(), V.data(),
true, true, true, stream);
// we use a temporary vector to avoid re-using w in the last step, the
// gemv, which could cause a very sporadic race condition on Pascal and
// Turing GPUs that made it produce wrong results. Details:
// https://github.com/rapidsai/cuml/issues/1739
rmm::device_uvector<math_t> tmp_vector(n_cols, stream);

raft::linalg::gemv(handle, U.data(), n_rows, n_cols, b, w, true, stream);
if (algo == 0 || n_cols == 1) {
raft::linalg::svdQR(handle, A, n_rows, n_cols, S.data(), U.data(), V.data(),
true, true, true, stream);
} else if (algo == 1) {
raft::linalg::svdEig(handle, A, n_rows, n_cols, S.data(), U.data(),
V.data(), true, stream);
}

raft::matrix::matrixVectorBinaryDivSkipZero(w, S.data(), 1, n_cols, false,
true, stream);
raft::linalg::gemv(handle, U.data(), n_rows, n_cols, b, tmp_vector.data(),
true, stream);

raft::linalg::gemv(handle, V.data(), n_cols, n_cols, w, w, false, stream);
}

template <typename math_t>
void lstsqEig(const raft::handle_t &handle, math_t *A, int n_rows, int n_cols,
math_t *b, math_t *w, cudaStream_t stream) {
auto allocator = handle.get_device_allocator();
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cublasHandle_t cublasH = handle.get_cublas_handle();

ASSERT(n_cols > 1, "lstsq: number of columns cannot be less than two");
ASSERT(n_rows > 1, "lstsq: number of rows cannot be less than two");

int U_len = n_rows * n_cols;
int V_len = n_cols * n_cols;

raft::mr::device::buffer<math_t> S(allocator, stream, n_cols);
raft::mr::device::buffer<math_t> V(allocator, stream, V_len);
raft::mr::device::buffer<math_t> U(allocator, stream, U_len);

raft::linalg::svdEig(handle, A, n_rows, n_cols, S.data(), U.data(), V.data(),
true, stream);

raft::linalg::gemv(handle, U.data(), n_rows, n_cols, b, w, true, stream);

raft::matrix::matrixVectorBinaryDivSkipZero(w, S.data(), 1, n_cols, false,
true, stream);
raft::matrix::matrixVectorBinaryDivSkipZero(tmp_vector.data(), S.data(), 1,
n_cols, false, true, stream);

raft::linalg::gemv(handle, V.data(), n_cols, n_cols, w, w, false, stream);
raft::linalg::gemv(handle, V.data(), n_cols, n_cols, tmp_vector.data(), w,
false, stream);
}

template <typename math_t>
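For reference, the consolidated `lstsq` composes the least-squares solution from a thin SVD: with A = U·diag(S)·Vᵀ, it computes w = V·diag(1/S)·Uᵀ·b, and the result of Uᵀ·b now lands in a separate `tmp_vector` so the final gemv never reads and writes the same buffer (the sporadic Pascal/Turing issue from #1739). The CPU-only sketch below illustrates that composition on a toy matrix whose SVD factors are hard-coded rather than computed with svdQR/svdEig; it is a minimal illustration under those assumptions, not the library code.

```cpp
// Illustrative CPU sketch of the least-squares composition used by lstsq():
// given a thin SVD A = U * diag(S) * V^T, the solution is
//   w = V * diag(1/S) * U^T * b.
// `tmp` plays the role of the commit's tmp_vector: U^T * b is kept separate
// from w so the final product never aliases its input and output.
#include <cstdio>
#include <vector>

int main() {
  const int n_rows = 3, n_cols = 2;

  // Toy 3x2 system whose thin SVD is trivial: A = [[1,0],[0,2],[0,0]],
  // so U = [[1,0],[0,1],[0,0]], S = [1,2], V = I (row-major storage).
  std::vector<double> U = {1, 0,
                           0, 1,
                           0, 0};
  std::vector<double> S = {1, 2};
  std::vector<double> V = {1, 0,
                           0, 1};
  std::vector<double> b = {3, 4, 0};  // right-hand side

  // tmp = U^T * b
  std::vector<double> tmp(n_cols, 0.0);
  for (int j = 0; j < n_cols; ++j)
    for (int i = 0; i < n_rows; ++i) tmp[j] += U[i * n_cols + j] * b[i];

  // tmp /= S, skipping zero singular values (mirrors matrixVectorBinaryDivSkipZero)
  for (int j = 0; j < n_cols; ++j)
    if (S[j] != 0.0) tmp[j] /= S[j];

  // w = V * tmp, written into a distinct buffer (the fix for #1739)
  std::vector<double> w(n_cols, 0.0);
  for (int i = 0; i < n_cols; ++i)
    for (int j = 0; j < n_cols; ++j) w[i] += V[i * n_cols + j] * tmp[j];

  printf("w = [%g, %g]\n", w[0], w[1]);  // expected: [3, 2]
  return 0;
}
```

For this A and b the normal equations give w = [3, 2], which matches the printed result.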
186 changes: 0 additions & 186 deletions python/cuml/experimental/explainer/base.py

This file was deleted.

(remaining changed files not rendered here)
