Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUBLAS wrappers with switchable host/device pointer mode #453

Merged
merged 4 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions cpp/include/raft/linalg/axpy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cuda_utils.cuh>
#include <raft/handle.hpp>
#include <raft/linalg/cublas_wrappers.h>

namespace raft::linalg {

template <typename T, bool DevicePointerMode = false>
void axpy(const raft::handle_t& handle,
const int n,
const T* alpha,
achirkin marked this conversation as resolved.
Show resolved Hide resolved
const T* x,
const int incx,
T* y,
const int incy,
cudaStream_t stream)
{
auto cublas_h = handle.get_cublas_handle();
cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream));
}

} // namespace raft::linalg
31 changes: 30 additions & 1 deletion cpp/include/raft/linalg/cublas_wrappers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -118,6 +118,35 @@ inline const char* cublas_error_to_string(cublasStatus_t err)
namespace raft {
namespace linalg {

/**
* Assuming the default CUBLAS_POINTER_MODE_HOST, change it to host or device mode
* temporary for the lifetime of this object.
*/
template <bool DevicePointerMode = false>
class cublas_device_pointer_mode {
public:
explicit cublas_device_pointer_mode(cublasHandle_t handle) : handle_(handle)
{
if constexpr (DevicePointerMode) {
RAFT_CUBLAS_TRY(cublasSetPointerMode(handle_, CUBLAS_POINTER_MODE_DEVICE));
}
}
auto operator=(const cublas_device_pointer_mode&) -> cublas_device_pointer_mode& = delete;
auto operator=(cublas_device_pointer_mode&&) -> cublas_device_pointer_mode& = delete;
static auto operator new(std::size_t) -> void* = delete;
static auto operator new[](std::size_t) -> void* = delete;

~cublas_device_pointer_mode()
{
if constexpr (DevicePointerMode) {
RAFT_CUBLAS_TRY_NO_THROW(cublasSetPointerMode(handle_, CUBLAS_POINTER_MODE_HOST));
}
}

private:
cublasHandle_t handle_ = nullptr;
};

/**
* @defgroup Axpy cublas ax+y operations
* @{
Expand Down
38 changes: 37 additions & 1 deletion cpp/include/raft/linalg/gemm.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,42 @@
namespace raft {
namespace linalg {

template <typename math_t, bool DevicePointerMode = false>
achirkin marked this conversation as resolved.
Show resolved Hide resolved
void gemm(const raft::handle_t& handle,
const bool trans_a,
const bool trans_b,
const int m,
const int n,
const int k,
const math_t* alpha,
const math_t* A,
const int lda,
const math_t* B,
const int ldb,
const math_t* beta,
const math_t* C,
const int ldc,
cudaStream_t stream)
{
cublasHandle_t cublas_h = handle.get_cublas_handle();
cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
RAFT_CUBLAS_TRY(cublasgemm(cublas_h,
trans_a ? CUBLAS_OP_T : CUBLAS_OP_N,
trans_b ? CUBLAS_OP_T : CUBLAS_OP_N,
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
stream));
}

/**
* @brief the wrapper of cublas gemm function
* It computes the following equation: D = alpha . opA(A) * opB(B) + beta . C
Expand Down
39 changes: 34 additions & 5 deletions cpp/include/raft/linalg/gemv.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,38 @@
namespace raft {
namespace linalg {

template <typename math_t, bool DevicePointerMode = false>
void gemv(const raft::handle_t& handle,
const bool trans_a,
achirkin marked this conversation as resolved.
Show resolved Hide resolved
const int m,
const int n,
const math_t* alpha,
const math_t* A,
const int lda,
const math_t* x,
const int incx,
const math_t* beta,
math_t* y,
const int incy,
cudaStream_t stream)
{
cublasHandle_t cublas_h = handle.get_cublas_handle();
cublas_device_pointer_mode<DevicePointerMode> pmode(cublas_h);
RAFT_CUBLAS_TRY(cublasgemv(cublas_h,
trans_a ? CUBLAS_OP_T : CUBLAS_OP_N,
m,
n,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
stream));
}

template <typename math_t>
void gemv(const raft::handle_t& handle,
const math_t* A,
Expand All @@ -39,10 +71,7 @@ void gemv(const raft::handle_t& handle,
const math_t beta,
cudaStream_t stream)
{
cublasHandle_t cublas_h = handle.get_cublas_handle();
cublasOperation_t op_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
RAFT_CUBLAS_TRY(
cublasgemv(cublas_h, op_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream));
gemv(handle, trans_a, n_rows, n_cols, &alpha, A, n_rows, x, incx, &beta, y, incy, stream);
}

/**
Expand Down