From 1d4128602718bdbfc4e81e039ba989fa23c90cfa Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 20 Jan 2022 09:16:09 +0100 Subject: [PATCH 1/2] CUBLAS wrappers with switchable host/device pointer mode --- cpp/include/raft/linalg/axpy.h | 40 +++++++++++++++++++++++ cpp/include/raft/linalg/cublas_wrappers.h | 31 +++++++++++++++++- cpp/include/raft/linalg/gemm.cuh | 38 ++++++++++++++++++++- cpp/include/raft/linalg/gemv.h | 39 +++++++++++++++++++--- 4 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 cpp/include/raft/linalg/axpy.h diff --git a/cpp/include/raft/linalg/axpy.h b/cpp/include/raft/linalg/axpy.h new file mode 100644 index 0000000000..4b4fedbbe0 --- /dev/null +++ b/cpp/include/raft/linalg/axpy.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::linalg { + +template +void axpy(const raft::handle_t& handle, + const int n, + const T* alpha, + const T* x, + const int incx, + T* y, + const int incy, + cudaStream_t stream) +{ + auto cublas_h = handle.get_cublas_handle(); + cublas_device_pointer_mode pmode(cublas_h); + RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream)); +} + +} // namespace raft::linalg diff --git a/cpp/include/raft/linalg/cublas_wrappers.h b/cpp/include/raft/linalg/cublas_wrappers.h index 024ed4a0e2..246e6466d8 100644 --- a/cpp/include/raft/linalg/cublas_wrappers.h +++ b/cpp/include/raft/linalg/cublas_wrappers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -118,6 +118,35 @@ inline const char* cublas_error_to_string(cublasStatus_t err) namespace raft { namespace linalg { +/** + * Assuming the default CUBLAS_POINTER_MODE_HOST, change it to host or device mode + * temporary for the lifetime of this object. + */ +template +class cublas_device_pointer_mode { + public: + explicit cublas_device_pointer_mode(cublasHandle_t handle) : handle_(handle) + { + if constexpr (DevicePointerMode) { + RAFT_CUBLAS_TRY(cublasSetPointerMode(handle_, CUBLAS_POINTER_MODE_DEVICE)); + } + } + auto operator=(const cublas_device_pointer_mode&) -> cublas_device_pointer_mode& = delete; + auto operator=(cublas_device_pointer_mode&&) -> cublas_device_pointer_mode& = delete; + static auto operator new(std::size_t) -> void* = delete; + static auto operator new[](std::size_t) -> void* = delete; + + ~cublas_device_pointer_mode() + { + if constexpr (DevicePointerMode) { + RAFT_CUBLAS_TRY_NO_THROW(cublasSetPointerMode(handle_, CUBLAS_POINTER_MODE_HOST)); + } + } + + private: + cublasHandle_t handle_ = nullptr; +}; + /** * @defgroup Axpy cublas ax+y operations * @{ diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index 9aff35619e..dcbd1c3c28 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,42 @@ namespace raft { namespace linalg { +template +void gemm(const raft::handle_t& handle, + const bool trans_a, + const bool trans_b, + const int m, + const int n, + const int k, + const math_t* alpha, + const math_t* A, + const int lda, + const math_t* B, + const int ldb, + const math_t* beta, + const math_t* C, + const int ldc, + cudaStream_t stream) +{ + cublasHandle_t cublas_h = handle.get_cublas_handle(); + cublas_device_pointer_mode pmode(cublas_h); + RAFT_CUBLAS_TRY(cublasgemm(cublas_h, + trans_a ? CUBLAS_OP_T : CUBLAS_OP_N, + trans_b ? CUBLAS_OP_T : CUBLAS_OP_N, + m, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc, + stream)); +} + /** * @brief the wrapper of cublas gemm function * It computes the following equation: D = alpha . opA(A) * opB(B) + beta . C diff --git a/cpp/include/raft/linalg/gemv.h b/cpp/include/raft/linalg/gemv.h index 462107df65..767a5bd4b0 100644 --- a/cpp/include/raft/linalg/gemv.h +++ b/cpp/include/raft/linalg/gemv.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,38 @@ namespace raft { namespace linalg { +template +void gemv(const raft::handle_t& handle, + const bool trans_a, + const int m, + const int n, + const math_t* alpha, + const math_t* A, + const int lda, + const math_t* x, + const int incx, + const math_t* beta, + math_t* y, + const int incy, + cudaStream_t stream) +{ + cublasHandle_t cublas_h = handle.get_cublas_handle(); + cublas_device_pointer_mode pmode(cublas_h); + RAFT_CUBLAS_TRY(cublasgemv(cublas_h, + trans_a ? CUBLAS_OP_T : CUBLAS_OP_N, + m, + n, + alpha, + A, + lda, + x, + incx, + beta, + y, + incy, + stream)); +} + template void gemv(const raft::handle_t& handle, const math_t* A, @@ -39,10 +71,7 @@ void gemv(const raft::handle_t& handle, const math_t beta, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); - cublasOperation_t op_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - RAFT_CUBLAS_TRY( - cublasgemv(cublas_h, op_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream)); + gemv(handle, trans_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream); } /** From a65242956d24dd27db1eeb2a46929a6930e4cf2c Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 3 Feb 2022 16:35:53 +0100 Subject: [PATCH 2/2] Add docs --- cpp/include/raft/linalg/axpy.h | 15 +++++++++++++++ cpp/include/raft/linalg/gemm.cuh | 22 ++++++++++++++++++++++ cpp/include/raft/linalg/gemv.h | 20 ++++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/cpp/include/raft/linalg/axpy.h b/cpp/include/raft/linalg/axpy.h index 4b4fedbbe0..27b14aea08 100644 --- a/cpp/include/raft/linalg/axpy.h +++ b/cpp/include/raft/linalg/axpy.h @@ -22,6 +22,21 @@ namespace raft::linalg { +/** + * @brief the wrapper of cublas axpy function + * It computes the following equation: y = alpha * x + y + * + * @tparam T the element type + * @tparam DevicePointerMode whether pointers alpha, beta point to device memory + * @param [in] handle raft handle + * @param [in] n number of elements in x and y + * @param [in] alpha host or device scalar + * @param [in] x vector of length n + * @param [in] incx stride between consecutive elements of x + * @param [inout] y vector of length n + * @param [in] incy stride between consecutive elements of y + * @param [in] stream + */ template void axpy(const raft::handle_t& handle, const int n, diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index dcbd1c3c28..b5147915ba 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -24,6 +24,28 @@ namespace raft { namespace linalg { +/** + * @brief the wrapper of cublas gemm function + * It computes the following equation: C = alpha .* opA(A) * opB(B) + beta .* C + * + * @tparam math_t the element type + * @tparam DevicePointerMode whether pointers alpha, beta point to device memory + * @param [in] handle raft handle + * @param [in] trans_a cublas transpose op for A + * @param [in] trans_b cublas transpose op for B + * @param [in] m number of rows of C + * @param [in] n number of columns of C + * @param [in] k number of rows of opB(B) / number of columns of opA(A) + * @param [in] alpha host or device scalar + * @param [in] A such a matrix that the shape of column-major opA(A) is [m, k] + * @param [in] lda leading dimension of A + * @param [in] B such a matrix that the shape of column-major opA(B) is [k, n] + * @param [in] ldb leading dimension of B + * @param [in] beta host or device scalar + * @param [inout] C column-major matrix of size [m, n] + * @param [in] ldc leading dimension of C + * @param [in] stream + */ template void gemm(const raft::handle_t& handle, const bool trans_a, diff --git a/cpp/include/raft/linalg/gemv.h b/cpp/include/raft/linalg/gemv.h index 767a5bd4b0..9eafb3941a 100644 --- a/cpp/include/raft/linalg/gemv.h +++ b/cpp/include/raft/linalg/gemv.h @@ -25,6 +25,26 @@ namespace raft { namespace linalg { +/** + * @brief the wrapper of cublas gemv function + * It computes the following equation: y = alpha .* op(A) * x + beta .* y + * + * @tparam math_t the element type + * @tparam DevicePointerMode whether pointers alpha, beta point to device memory + * @param [in] handle raft handle + * @param [in] trans_a cublas transpose op for A + * @param [in] m number of rows of A + * @param [in] n number of columns of A + * @param [in] alpha host or device scalar + * @param [in] A column-major matrix of size [m, n] + * @param [in] lda leading dimension of A + * @param [in] x vector of length n if trans_a else m + * @param [in] incx stride between consecutive elements of x + * @param [in] beta host or device scalar + * @param [inout] y vector of length m if trans_a else n + * @param [in] incy stride between consecutive elements of y + * @param [in] stream + */ template void gemv(const raft::handle_t& handle, const bool trans_a,