Moving kernel gram primitives to raft::distance::kernels (#920)
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #920
cjnolet authored Oct 18, 2022
1 parent 5f56933 commit b4939b7
Showing 24 changed files with 1,747 additions and 1 deletion.
10 changes: 10 additions & 0 deletions cpp/CMakeLists.txt
@@ -248,11 +248,21 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
src/distance/specializations/detail/cosine.cu
src/distance/specializations/detail/hamming_unexpanded.cu
src/distance/specializations/detail/hellinger_expanded.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
# These are somehow missing a kernel definition, which causes a compile error.
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# src/distance/specializations/detail/kernels/rbf_kernel_float.cu
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
@@ -85,6 +85,7 @@ if(BUILD_BENCH)
bench/distance/distance_exp_l2.cu
bench/distance/distance_l1.cu
bench/distance/distance_unexp_l2.cu
bench/distance/kernels.cu
bench/main.cpp
OPTIONAL DIST
)
2 changes: 1 addition & 1 deletion cpp/bench/distance/distance_common.cuh
@@ -15,8 +15,8 @@
*/

#include <common/benchmark.hpp>
#include <raft/cudart_utils.h>
#include <raft/distance/distance.cuh>
#include <raft/util/cudart_utils.hpp>
#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#endif
123 changes: 123 additions & 0 deletions cpp/bench/distance/kernels.cu
@@ -0,0 +1,123 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#endif

#include <common/benchmark.hpp>
#include <memory>
#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/random/rng.cuh>
#include <sstream>
#include <string>
#include <vector>

namespace raft::bench::distance::kernels {

using namespace raft::distance::kernels;
struct GramTestParams {
int m; // m parameter of the GEMM
int k; // k parameter of the GEMM
int n; // n parameter of the GEMM
KernelParams kernel_params;
bool is_row_major;
}; // struct GramTestParams

template <typename T>
struct GramMatrix : public fixture {
GramMatrix(const GramTestParams& p)
: params(p), handle(stream), A(0, stream), B(0, stream), C(0, stream)
{
kernel = std::unique_ptr<GramMatrixBase<T>>(
KernelFactory<T>::create(p.kernel_params, handle.get_cublas_handle()));

A.resize(params.m * params.k, stream);
B.resize(params.k * params.n, stream);
C.resize(params.m * params.n, stream);
raft::random::Rng r(123456ULL);
r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream);
}

~GramMatrix()
{
A.release();
B.release();
C.release();
}

void run_benchmark(::benchmark::State& state) override
{
if (!this->kernel) { state.SkipWithError("Kernel matrix is not initialized"); }
loop_on_state(state, [this]() {
(*this->kernel)(A.data(),
this->params.m,
this->params.k,
B.data(),
this->params.n,
C.data(),
this->params.is_row_major,
this->stream);
});
}

private:
const raft::handle_t handle;
std::unique_ptr<GramMatrixBase<T>> kernel;
GramTestParams params;

rmm::device_uvector<T> A; // input matrix A, size [m * k]
rmm::device_uvector<T> B; // input matrix B, size [n * k]
rmm::device_uvector<T> C; // output matrix C, size [m*n]
};

static std::vector<GramTestParams> getInputs()
{
std::vector<GramTestParams> param_vec;
std::vector<KernelParams> kernel_params{KernelParams{LINEAR, 3, 1, 0},
KernelParams{POLYNOMIAL, 2, 1.3, 1},
KernelParams{TANH, 2, 0.5, 2.4},
KernelParams{RBF, 2, 0.5, 0}};
struct TestSize {
int m;
int k;
int n;
};
std::vector<TestSize> data_size{{4096, 10, 1024},
{4096, 100, 1024},
{4096, 1000, 1024},
{4096, 10000, 1024},
{100000, 10, 1024},
{100000, 100, 1024},
{100000, 1000, 1024}};

param_vec.reserve(kernel_params.size() * data_size.size());
for (TestSize s : data_size) {
for (auto kernel : kernel_params) {
for (bool row_major : {false, true}) {
param_vec.push_back(GramTestParams{s.m, s.k, s.n, kernel, row_major});
}
}
}
return param_vec;
}

RAFT_BENCH_REGISTER(GramMatrix<float>, "", getInputs());
RAFT_BENCH_REGISTER(GramMatrix<double>, "", getInputs());

} // namespace raft::bench::distance::kernels
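For reference, here is a minimal standalone sketch of driving the relocated Gram kernel API directly, mirroring what the benchmark above exercises. The sizes and kernel parameters are illustrative; only the public raft/distance/kernels.cuh entry points already used in the benchmark are assumed.

#include <memory>

#include <raft/core/handle.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/random/rng.cuh>
#include <rmm/device_uvector.hpp>

int main()
{
  using namespace raft::distance::kernels;

  raft::handle_t handle;
  cudaStream_t stream = handle.get_stream();

  int m = 128, k = 16, n = 64;
  rmm::device_uvector<float> A(m * k, stream);  // m vectors of length k
  rmm::device_uvector<float> B(k * n, stream);  // n vectors of length k
  rmm::device_uvector<float> C(m * n, stream);  // m x n Gram matrix output

  raft::random::Rng rng(42ULL);
  rng.uniform(A.data(), m * k, -1.0f, 1.0f, stream);
  rng.uniform(B.data(), k * n, -1.0f, 1.0f, stream);

  // Polynomial kernel K(x, y) = (gamma * <x, y> + coef0)^degree.
  KernelParams params{POLYNOMIAL, /*degree=*/3, /*gamma=*/1.0, /*coef0=*/1.0};
  std::unique_ptr<GramMatrixBase<float>> kernel(
    KernelFactory<float>::create(params, handle.get_cublas_handle()));

  // Row-major result: C[i * n + j] = K(A_i, B_j).
  (*kernel)(A.data(), m, k, B.data(), n, C.data(), /*is_row_major=*/true, stream);
  handle.sync_stream();
  return 0;
}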
218 changes: 218 additions & 0 deletions cpp/include/raft/distance/detail/kernels/gram_matrix.cuh
@@ -0,0 +1,218 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance.cuh>

#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/gemm.cuh>

namespace raft::distance::kernels::detail {

/**
* Base class for general Gram matrices
* A Gram matrix is the Hermitian matrix of inner products G_ik = <x_i, x_k>
* Here, the inner product is evaluated for all elements from vectors sets X1,
* and X2.
*
* To be more precise, on exit the output buffer will store:
* - if is_row_major == true: out[j*n2 + k] = <x1_j, x2_k>,
* - if is_row_major == false: out[j + k*n1] = <x1_j, x2_k>,
* where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector
* from the x2 set.
*/
template <typename math_t>
class GramMatrixBase {
cublasHandle_t cublas_handle;

public:
GramMatrixBase(cublasHandle_t cublas_handle) : cublas_handle(cublas_handle){};

virtual ~GramMatrixBase(){};

/** Convenience function to evaluate the Gram matrix for two vector sets.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number of vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number of vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
virtual void operator()(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1 = 0,
int ld2 = 0,
int ld_out = 0)
{
if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; }
if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; }
if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; }
evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out);
}
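// Note: a leading dimension of 0 (the default) above means "tightly packed":
// for row-major storage the leading dimension equals the number of columns
// (n_cols for the inputs, n2 for out); for column-major storage it equals
// the number of rows (n1, n2, and n1, respectively).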

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number of vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number of vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1 (usually it is n1)
* @param ld2 leading dimension of x2 (usually it is n2)
* @param ld_out leading dimension of out (usually it is n1)
*/
virtual void evaluate(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1,
int ld2,
int ld_out)
{
linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out);
}

// private:
// The following methods should be private, they are kept public to avoid:
// "error: The enclosing parent function ("distance") for an extended
// __device__ lambda cannot have private or protected access within its class"

/** Calculates the Gram matrix using simple dot product between vector sets.
*
* out = x1 * x2
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number of vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number of vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
void linear(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1,
int ld2,
int ld_out)
{
math_t alpha = 1.0;
math_t beta = 0.0;
if (is_row_major) {
// #TODO: Call from public API when ready
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
n2,
n1,
n_cols,
&alpha,
x2,
ld2,
x1,
ld1,
&beta,
out,
ld_out,
stream));
} else {
// #TODO: Call from public API when ready
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
n1,
n2,
n_cols,
&alpha,
x1,
ld1,
x2,
ld2,
&beta,
out,
ld_out,
stream));
}
}
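// Note on the two GEMM calls above: cuBLAS assumes column-major storage.
// In the row-major branch, the call computes the column-major product
// C = x2 * x1^T of shape [n2, n1], whose memory layout coincides with the
// row-major [n1, n2] result out = x1 * x2^T. The column-major branch
// computes out = x1 * x2^T directly.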

/** Calculates the Gram matrix using Euclidean distance.
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number of vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number of vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
virtual void distance(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1,
int ld2,
int ld_out)
{
raft::distance::distance<raft::distance::DistanceType::L2Unexpanded, math_t, math_t, math_t>(
x1, x2, out, n1, n2, n_cols, stream, is_row_major);
}
};
}; // end namespace raft::distance::kernels::detail
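As a cross-check of the layout convention documented in GramMatrixBase, a plain host-side reference for the linear (dot-product) Gram matrix might look as follows; the helper name is illustrative and not part of RAFT.

// Illustrative host-side reference (not part of RAFT) for the linear Gram
// matrix, following the layout convention documented in GramMatrixBase.
template <typename T>
void linear_gram_host(const T* x1, const T* x2, T* out,
                      int n1, int n2, int n_cols, bool is_row_major)
{
  for (int j = 0; j < n1; ++j) {
    for (int k = 0; k < n2; ++k) {
      T dot = T(0);
      for (int c = 0; c < n_cols; ++c) {
        // The inputs share the layout of the output.
        T a = is_row_major ? x1[j * n_cols + c] : x1[j + c * n1];
        T b = is_row_major ? x2[k * n_cols + c] : x2[k + c * n2];
        dot += a * b;
      }
      // Row-major: out[j*n2 + k]; column-major: out[j + k*n1].
      if (is_row_major) {
        out[j * n2 + k] = dot;
      } else {
        out[j + k * n1] = dot;
      }
    }
  }
}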
