Row major Gram matrices (#3639)
This PR adds row-major Gram matrices. These will be used in the SVM kernels to allow flexibility in the input layout (#2198).

For the benchmarked cases, row-major input is around 2.5% slower on average.

![image](https://user-images.githubusercontent.com/3671106/111769429-8361b000-88a9-11eb-85e1-145caf5f42b2.png)
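
A rough usage sketch of the extended interface (the include path, namespace, and wrapper function below are assumptions for illustration; the `operator()` signature and the leading-dimension defaults follow the diff in `grammatrix.cuh`):

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <matrix/grammatrix.cuh>  // assumed include path

// Hypothetical call site: d_x1 is n1 x n_cols, d_x2 is n2 x n_cols and
// d_out is n1 x n2, all device buffers stored in row-major order.
void gram_row_major_example(const float* d_x1, int n1, const float* d_x2,
                            int n2, int n_cols, float* d_out,
                            cublasHandle_t cublas_handle, cudaStream_t stream) {
  MLCommon::Matrix::GramMatrixBase<float> gram(cublas_handle);
  bool is_row_major = true;
  // With ld1/ld2/ld_out left at their defaults (0), the leading dimensions
  // fall back to n_cols, n_cols and n2 respectively for row-major buffers.
  gram(d_x1, n1, n_cols, d_x2, n2, d_out, is_row_major, stream);
}
```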

Authors:
  - Tamas Bela Feher (@tfeher)

Approvers:
  - Thejaswi. N. S (@teju85)
  - Dante Gama Dessavre (@dantegd)

URL: #3639
tfeher authored Mar 31, 2021
1 parent 8316807 commit b10f31f
Showing 6 changed files with 298 additions and 238 deletions.
13 changes: 9 additions & 4 deletions cpp/bench/prims/gram_matrix.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -35,6 +35,7 @@ struct GramTestParams {
int k; // k parameter of the GEMM
int n; // n parameter of the GEMM
KernelParams kernel_params;
bool is_row_major;
}; // struct GramTestParams

template <typename T>
@@ -46,7 +47,8 @@ struct GramMatrix : public Fixture {
std::vector<std::string> kernel_names{"linear", "poly", "rbf", "tanh"};
std::ostringstream oss;
oss << name << "/" << kernel_names[p.kernel_params.kernel] << "/" << p.m
<< "x" << p.k << "x" << p.n;
<< "x" << p.k << "x" << p.n << "/"
<< (p.is_row_major ? "row_major" : "col_major");
this->SetName(oss.str().c_str());

CUBLAS_CHECK(cublasCreate(&cublas_handle));
@@ -78,7 +80,8 @@ struct GramMatrix : public Fixture {
}
loopOnState(state, [this]() {
(*this->kernel)(this->A, this->params.m, this->params.k, this->B,
this->params.n, this->C, this->stream);
this->params.n, this->C, this->params.is_row_major,
this->stream);
});
}

@@ -110,7 +113,9 @@ static std::vector<GramTestParams> getInputs() {
param_vec.reserve(kernel_params.size() * data_size.size());
for (TestSize s : data_size) {
for (auto kernel : kernel_params) {
param_vec.push_back(GramTestParams{s.m, s.k, s.n, kernel});
for (bool row_major : {false, true}) {
param_vec.push_back(GramTestParams{s.m, s.k, s.n, kernel, row_major});
}
}
}
return param_vec;
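
With the new flag, every (size, kernel) combination produced by `getInputs()` is generated twice, and the layout appears as a suffix in the benchmark name built by the fixture. The resulting names look roughly like the following (the `GramMatrix` prefix and the concrete sizes are only illustrative):

```
GramMatrix/linear/1024x64x128/col_major
GramMatrix/linear/1024x64x128/row_major
```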
5 changes: 3 additions & 2 deletions cpp/src/svm/kernelcache.cuh
@@ -208,7 +208,8 @@ class KernelCache {
x_ws.data(), ws_idx_new,
non_cached, stream, false);
math_t *tile_new = tile.data() + (size_t)n_cached * n_rows;
(*kernel)(x, n_rows, n_cols, x_ws.data(), non_cached, tile_new, stream);
(*kernel)(x, n_rows, n_cols, x_ws.data(), non_cached, tile_new, false,
stream);
// We need AssignCacheIdx to be finished before calling StoreCols
cache.StoreVecs(tile_new, n_rows, non_cached,
ws_cache_idx.data() + n_cached, stream);
@@ -219,7 +220,7 @@
raft::matrix::copyRows<math_t, int, size_t>(
x, n_rows, n_cols, x_ws.data(), unique_idx.data(), n_unique, stream,
false);
(*kernel)(x, n_rows, n_cols, x_ws.data(), n_unique, tile.data(),
(*kernel)(x, n_rows, n_cols, x_ws.data(), n_unique, tile.data(), false,
stream);
}
}
2 changes: 1 addition & 1 deletion cpp/src/svm/svc_impl.cuh
@@ -146,7 +146,7 @@ void svcPredict(const raft::handle_t &handle, math_t *input, int n_rows,
ld1 = n_rows;
}
kernel->evaluate(x_ptr, n_batch, n_cols, model.x_support, model.n_support,
K.data(), stream, ld1, model.n_support, n_batch);
K.data(), false, stream, ld1, model.n_support, n_batch);
math_t one = 1;
math_t null = 0;
CUBLAS_CHECK(raft::linalg::cublasgemv(
112 changes: 60 additions & 52 deletions cpp/src_prims/matrix/grammatrix.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -30,8 +30,10 @@ namespace Matrix {
* and X2.
*
* To be more precise, on exit the output buffer will store:
* out[j+k*n1] = <x1_j, x2_k> where x1_j is the j-th vector from the x1 set
* and x2_k is the k-th vector from the x2 set.
* - if is_row_major == true: out[j*n2 + k] = <x1_j, x2_k>,
* - if is_row_major == false: out[j + k*n1] = <x1_j, x2_k>,
* where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector
* from the x2 set.
*/
template <typename math_t>
class GramMatrixBase {
@@ -44,56 +46,55 @@ class GramMatrixBase {

/** Convenience function to evaluate the Gram matrix for two vector sets.
*
* @param [in] x1 device array of vectors in column major format,
* size [n1*n_cols]
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors in column major format,
* size [n2*n_cols]
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix in column major
* format, size [n1*n2]
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1 (usually it is n1)
* @param ld2 leading dimension of x2 (usually it is n2)
* @param ld_out leading dimension of out (usually it is n1)
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
virtual void operator()(const math_t *x1, int n1, int n_cols,
const math_t *x2, int n2, math_t *out,
cudaStream_t stream, int ld1 = 0, int ld2 = 0,
int ld_out = 0) {
bool is_row_major, cudaStream_t stream, int ld1 = 0,
int ld2 = 0, int ld_out = 0) {
if (ld1 <= 0) {
ld1 = n1;
ld1 = is_row_major ? n_cols : n1;
}
if (ld2 <= 0) {
ld2 = n2;
ld2 = is_row_major ? n_cols : n2;
}
if (ld_out <= 0) {
ld_out = n1;
ld_out = is_row_major ? n2 : n1;
}
evaluate(x1, n1, n_cols, x2, n2, out, stream, ld1, ld2, ld_out);
evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2,
ld_out);
}

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] x1 device array of vectors in column major format,
* size [n1*n_cols]
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors in column major format,
* size [n2*n_cols]
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix in column major
* format, size [n1*n2]
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1 (usually it is n1)
* @param ld2 leading dimension of x2 (usually it is n2)
* @param ld_out leading dimension of out (usually it is n1)
*/
virtual void evaluate(const math_t *x1, int n1, int n_cols, const math_t *x2,
int n2, math_t *out, cudaStream_t stream, int ld1,
int ld2, int ld_out) {
linear(x1, n1, n_cols, x2, n2, out, stream, ld1, ld2, ld_out);
int n2, math_t *out, bool is_row_major,
cudaStream_t stream, int ld1, int ld2, int ld_out) {
linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out);
}

//private:
@@ -102,58 +103,65 @@ class GramMatrixBase {
// __device__ lambda cannot have private or protected access within its class"

/** Calculates the Gram matrix using simple dot product between vector sets.
*
* out = x1 * x2
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] x1 device array of vectors in column major format,
* size [n1*n_cols]
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors in column major format,
* size [n2*n_cols]
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix in column major
* format, size [n1*n2]
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1 (usually it is n1)
* @param ld2 leading dimension of x2 (usually it is n2)
* @param ld_out leading dimension of out (usually it is n1)
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
void linear(const math_t *x1, int n1, int n_cols, const math_t *x2, int n2,
math_t *out, cudaStream_t stream, int ld1, int ld2, int ld_out) {
math_t *out, bool is_row_major, cudaStream_t stream, int ld1,
int ld2, int ld_out) {
math_t alpha = 1.0;
math_t beta = 0.0;
CUBLAS_CHECK(raft::linalg::cublasgemm(
cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, n1, n2, n_cols, &alpha, x1, ld1,
x2, ld2, &beta, out, ld_out, stream));
if (is_row_major) {
CUBLAS_CHECK(raft::linalg::cublasgemm(
cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, n2, n1, n_cols, &alpha, x2,
ld2, x1, ld1, &beta, out, ld_out, stream));
} else {
CUBLAS_CHECK(raft::linalg::cublasgemm(
cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, n1, n2, n_cols, &alpha, x1,
ld1, x2, ld2, &beta, out, ld_out, stream));
}
}

/** Calculates the Gram matrix using Euclidean distance.
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] x1 device array of vectors in column major format,
* size [n1*n_cols]
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors in column major format,
* size [n2*n_cols]
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix in column major
* format, size [n1*n2]
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1 (usually it is n1)
* @param ld2 leading dimension of x2 (usually it is n2)
* @param ld_out leading dimension of out (usually it is n1)
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
virtual void distance(const math_t *x1, int n1, int n_cols, const math_t *x2,
int n2, math_t *out, cudaStream_t stream, int ld1,
int ld2, int ld_out) {
int n2, math_t *out, bool is_row_major,
cudaStream_t stream, int ld1, int ld2, int ld_out) {
typedef cutlass::Shape<8, 128, 128> OutputTile_t;
auto fin_op = [] __device__(math_t d_val, int idx) { return d_val; };
Distance::distance<raft::distance::DistanceType::L2Unexpanded, math_t,
math_t, math_t, OutputTile_t>(
x1, x2, out, n1, n2, n_cols, NULL, 0, fin_op, stream, false);
x1, x2, out, n1, n2, n_cols, NULL, 0, fin_op, stream, is_row_major);
}
};
}; // end namespace Matrix
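
The row-major branch of `linear()` relies on a layout identity rather than a separate code path: an `n1 x n2` row-major buffer is bit-for-bit identical to its transpose stored column-major, so the column-major cuBLAS GEMM can fill it by computing the transposed product with the operands swapped (a sketch of the reasoning, not part of the diff):

$$
\mathrm{out} = X_1 X_2^{\mathsf{T}} \in \mathbb{R}^{n_1 \times n_2}
\quad\Longleftrightarrow\quad
\mathrm{out}^{\mathsf{T}} = X_2 X_1^{\mathsf{T}} \in \mathbb{R}^{n_2 \times n_1},
$$

which is exactly the `CUBLAS_OP_T, CUBLAS_OP_N` call above with `m = n2`, `n = n1`, `k = n_cols`, and every leading dimension equal to the corresponding row length (`n_cols` for the inputs, `n2` for the output).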