From 0908af75efa9287d99125db02afdb96044ea94c4 Mon Sep 17 00:00:00 2001
From: achirkin
Date: Tue, 8 Feb 2022 14:33:51 +0100
Subject: [PATCH] Fix badly merged cublas wrappers

---
 cpp/include/raft/linalg/{axpy.h => axpy.hpp} |  8 +---
 cpp/include/raft/linalg/detail/axpy.hpp      | 43 ++++++++++++++++++++
 cpp/include/raft/linalg/detail/gemm.hpp      |  4 +-
 cpp/include/raft/linalg/detail/gemv.hpp      | 37 +++++++++++++++--
 cpp/include/raft/linalg/gemm.hpp             | 43 ++++++++++++++++++++
 cpp/include/raft/linalg/gemv.hpp             | 22 ++--------
 6 files changed, 127 insertions(+), 30 deletions(-)
 rename cpp/include/raft/linalg/{axpy.h => axpy.hpp} (84%)
 create mode 100644 cpp/include/raft/linalg/detail/axpy.hpp

diff --git a/cpp/include/raft/linalg/axpy.h b/cpp/include/raft/linalg/axpy.hpp
similarity index 84%
rename from cpp/include/raft/linalg/axpy.h
rename to cpp/include/raft/linalg/axpy.hpp
index 27b14aea08..5a5a873132 100644
--- a/cpp/include/raft/linalg/axpy.h
+++ b/cpp/include/raft/linalg/axpy.hpp
@@ -16,9 +16,7 @@
 
 #pragma once
 
-#include
-#include
-#include
+#include "detail/axpy.hpp"
 
 namespace raft::linalg {
 
@@ -47,9 +45,7 @@ void axpy(const raft::handle_t& handle,
           const int incy,
           cudaStream_t stream)
 {
-  auto cublas_h = handle.get_cublas_handle();
-  cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
-  RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream));
+  detail::axpy<T, DevicePointerMode>(handle, n, alpha, x, incx, y, incy, stream);
 }
 
 }  // namespace raft::linalg

diff --git a/cpp/include/raft/linalg/detail/axpy.hpp b/cpp/include/raft/linalg/detail/axpy.hpp
new file mode 100644
index 0000000000..f5527bf10f
--- /dev/null
+++ b/cpp/include/raft/linalg/detail/axpy.hpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+#include "cublas_wrappers.hpp"
+
+#include
+#include
+
+namespace raft::linalg::detail {
+
+template <typename T, bool DevicePointerMode = false>
+void axpy(const raft::handle_t& handle,
+          const int n,
+          const T* alpha,
+          const T* x,
+          const int incx,
+          T* y,
+          const int incy,
+          cudaStream_t stream)
+{
+  auto cublas_h = handle.get_cublas_handle();
+  cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
+  RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream));
+}
+
+}  // namespace raft::linalg::detail

diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp
index 8a02b702e5..0ea1723a9e 100644
--- a/cpp/include/raft/linalg/detail/gemm.hpp
+++ b/cpp/include/raft/linalg/detail/gemm.hpp
@@ -16,8 +16,10 @@
 
 #pragma once
 
-#include "cublas_wrappers.hpp"
 #include
+
+#include "cublas_wrappers.hpp"
+
 #include
 #include

diff --git a/cpp/include/raft/linalg/detail/gemv.hpp b/cpp/include/raft/linalg/detail/gemv.hpp
index 991268cf26..3692743152 100644
--- a/cpp/include/raft/linalg/detail/gemv.hpp
+++ b/cpp/include/raft/linalg/detail/gemv.hpp
@@ -27,6 +27,38 @@
 namespace raft {
 namespace linalg {
 namespace detail {
+template <typename math_t, bool DevicePointerMode = false>
+void gemv(const raft::handle_t& handle,
+          const bool trans_a,
+          const int m,
+          const int n,
+          const math_t* alpha,
+          const math_t* A,
+          const int lda,
+          const math_t* x,
+          const int incx,
+          const math_t* beta,
+          math_t* y,
+          const int incy,
+          cudaStream_t stream)
+{
+  cublasHandle_t cublas_h = handle.get_cublas_handle();
+  detail::cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
+  RAFT_CUBLAS_TRY(detail::cublasgemv(cublas_h,
+                                     trans_a ? CUBLAS_OP_T : CUBLAS_OP_N,
+                                     m,
+                                     n,
+                                     alpha,
+                                     A,
+                                     lda,
+                                     x,
+                                     incx,
+                                     beta,
+                                     y,
+                                     incy,
+                                     stream));
+}
+
 template <typename math_t>
 void gemv(const raft::handle_t& handle,
           const math_t* A,
@@ -41,10 +73,7 @@ void gemv(const raft::handle_t& handle,
           const math_t beta,
           cudaStream_t stream)
 {
-  cublasHandle_t cublas_h = handle.get_cublas_handle();
-  cublasOperation_t op_a  = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
-  RAFT_CUBLAS_TRY(
-    cublasgemv(cublas_h, op_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream));
+  gemv(handle, trans_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream);
 }
 
 template <typename math_t>

diff --git a/cpp/include/raft/linalg/gemm.hpp b/cpp/include/raft/linalg/gemm.hpp
index 04ddbb3561..f22d15e650 100644
--- a/cpp/include/raft/linalg/gemm.hpp
+++ b/cpp/include/raft/linalg/gemm.hpp
@@ -21,6 +21,49 @@
 namespace raft {
 namespace linalg {
 
+/**
+ * @brief the wrapper of cublas gemm function
+ *  It computes the following equation: C = alpha .* opA(A) * opB(B) + beta .* C
+ *
+ * @tparam math_t the element type
+ * @tparam DevicePointerMode whether pointers alpha, beta point to device memory
+ * @param [in] handle raft handle
+ * @param [in] trans_a cublas transpose op for A
+ * @param [in] trans_b cublas transpose op for B
+ * @param [in] m number of rows of C
+ * @param [in] n number of columns of C
+ * @param [in] k number of rows of opB(B) / number of columns of opA(A)
+ * @param [in] alpha host or device scalar
+ * @param [in] A such a matrix that the shape of column-major opA(A) is [m, k]
+ * @param [in] lda leading dimension of A
+ * @param [in] B such a matrix that the shape of column-major opB(B) is [k, n]
+ * @param [in] ldb leading dimension of B
+ * @param [in] beta host or device scalar
+ * @param [inout] C column-major matrix of size [m, n]
+ * @param [in] ldc leading dimension of C
+ * @param [in] stream
+ */
+template <typename math_t, bool DevicePointerMode = false>
+void gemm(const raft::handle_t& handle,
+          const bool trans_a,
+          const bool trans_b,
+          const int m,
+          const int n,
+          const int k,
+          const math_t* alpha,
+          const math_t* A,
+          const int lda,
+          const math_t* B,
+          const int ldb,
+          const math_t* beta,
+          math_t* C,
+          const int ldc,
+          cudaStream_t stream)
+{
+  detail::gemm<math_t, DevicePointerMode>(
+    handle, trans_a, trans_b, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, stream);
+}
+
 /**
  * @brief the wrapper of cublas gemm function
  *  It computes the following equation: D = alpha . opA(A) * opB(B) + beta . C

diff --git a/cpp/include/raft/linalg/gemv.hpp b/cpp/include/raft/linalg/gemv.hpp
index 45766b8c9a..2098027b16 100644
--- a/cpp/include/raft/linalg/gemv.hpp
+++ b/cpp/include/raft/linalg/gemv.hpp
@@ -17,9 +17,6 @@
 #pragma once
 
 #include "detail/gemv.hpp"
-#include
-
-#include
 
 namespace raft {
 namespace linalg {
@@ -59,21 +56,8 @@ void gemv(const raft::handle_t& handle,
           const int incy,
           cudaStream_t stream)
 {
-  cublasHandle_t cublas_h = handle.get_cublas_handle();
-  detail::cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
-  RAFT_CUBLAS_TRY(detail::cublasgemv(cublas_h,
-                                     trans_a ? CUBLAS_OP_T : CUBLAS_OP_N,
-                                     m,
-                                     n,
-                                     alpha,
-                                     A,
-                                     lda,
-                                     x,
-                                     incx,
-                                     beta,
-                                     y,
-                                     incy,
-                                     stream));
+  detail::gemv<math_t, DevicePointerMode>(
+    handle, trans_a, m, n, alpha, A, lda, x, incx, beta, y, incy, stream);
 }
 
 template <typename math_t>
@@ -90,7 +74,7 @@ void gemv(const raft::handle_t& handle,
           const math_t beta,
           cudaStream_t stream)
 {
-  gemv(handle, trans_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream);
+  detail::gemv(handle, A, n_rows, n_cols, x, incx, y, incy, trans_a, alpha, beta, stream);
 }
 
 /**
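
Usage sketch (illustrative only, not part of the patch): how the public wrappers are expected to be called after this change. The signatures are the ones shown in the diff above; the handle/stream setup, the sizes, and the buffer names (x, y, A, B, C) are assumptions for the sake of the example.

    #include <raft/handle.hpp>
    #include <raft/linalg/axpy.hpp>
    #include <raft/linalg/gemm.hpp>
    #include <rmm/device_uvector.hpp>

    void example(const raft::handle_t& handle)
    {
      cudaStream_t stream = handle.get_stream();
      const int n = 4, m = 3, k = 2;

      // Hypothetical device buffers; real code would fill them with data first.
      rmm::device_uvector<float> x(n, stream), y(n, stream);
      rmm::device_uvector<float> A(m * k, stream), B(k * n, stream), C(m * n, stream);

      // y = alpha * x + y; DevicePointerMode defaults to false, so alpha is a host scalar.
      const float alpha = 2.0f;
      raft::linalg::axpy(handle, n, &alpha, x.data(), 1, y.data(), 1, stream);

      // C = alpha * opA(A) * opB(B) + beta * C with column-major A [m, k] and B [k, n],
      // via the pointer-based gemm overload added in this patch (lda = m, ldb = k, ldc = m).
      const float beta = 0.0f;
      raft::linalg::gemm(
        handle, false, false, m, n, k, &alpha, A.data(), m, B.data(), k, &beta, C.data(), m, stream);
    }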
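
On the DevicePointerMode flag documented above ("whether pointers alpha, beta point to device memory"): when it is set, the RAII cublas_device_pointer_mode guard in the detail layer switches the cublas handle to CUBLAS_POINTER_MODE_DEVICE for the duration of the call, so the scalars can stay on the GPU. A sketch, assuming the <T, DevicePointerMode> template parameters as reconstructed in this patch and a hypothetical d_alpha buffer:

    // d_alpha holds the scalar in device memory, e.g. produced by an earlier kernel.
    rmm::device_uvector<float> d_alpha(1, stream);

    // DevicePointerMode = true: the wrapper reads alpha from device memory,
    // avoiding a device-to-host copy between dependent operations.
    raft::linalg::axpy<float, true>(handle, n, d_alpha.data(), x.data(), 1, y.data(), 1, stream);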