[BUG] Fix for gemmi deprecation #1020

Merged 9 commits on Nov 16, 2022
2 changes: 2 additions & 0 deletions build.sh
@@ -65,6 +65,7 @@ CMAKE_LOG_LEVEL=""
VERBOSE_FLAG=""
BUILD_ALL_GPU_ARCH=0
BUILD_TESTS=OFF
+ BUILD_TYPE=Release
BUILD_BENCH=OFF
BUILD_STATIC_FAISS=OFF
COMPILE_LIBRARIES=OFF
@@ -336,6 +337,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has
cmake -S ${REPODIR}/cpp -B ${LIBRAFT_BUILD_DIR} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \
-DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \
+ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \
-DRAFT_ENABLE_NN_DEPENDENCIES=${ENABLE_NN_DEPENDENCIES} \
-DRAFT_NVTX=${NVTX} \
99 changes: 88 additions & 11 deletions cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -19,6 +19,7 @@
#include <cusparse.h>
#include <raft/core/cusparse_macros.hpp>
#include <raft/core/error.hpp>
+ #include <raft/linalg/transpose.cuh>
#include <rmm/device_uvector.hpp>

namespace raft {
@@ -650,6 +651,73 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle,
* @defgroup Gemmi cusparse gemmi operations
* @{
*/
+ #if CUDART_VERSION < 12000
+ template <typename T>
+ cusparseStatus_t cusparsegemmi( // NOLINT
+ cusparseHandle_t handle,
+ int m,
+ int n,
+ int k,
+ int nnz,
+ const T* alpha,
+ const T* A,
+ int lda,
+ const T* cscValB,
+ const int* cscColPtrB,
+ const int* cscRowIndB,
+ const T* beta,
+ T* C,
+ int ldc,
+ cudaStream_t stream);
+ template <>
+ inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle,
+ int m,
+ int n,
+ int k,
+ int nnz,
+ const float* alpha,
+ const float* A,
+ int lda,
+ const float* cscValB,
+ const int* cscColPtrB,
+ const int* cscRowIndB,
+ const float* beta,
+ float* C,
+ int ldc,
+ cudaStream_t stream)
+ {
+ CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+ return cusparseSgemmi(
+ handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc);
+ #pragma GCC diagnostic pop
+ }
+ template <>
+ inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle,
+ int m,
+ int n,
+ int k,
+ int nnz,
+ const double* alpha,
+ const double* A,
+ int lda,
+ const double* cscValB,
+ const int* cscColPtrB,
+ const int* cscRowIndB,
+ const double* beta,
+ double* C,
+ int ldc,
+ cudaStream_t stream)
+ {
+ CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+ return cusparseDgemmi(
+ handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc);
+ #pragma GCC diagnostic pop
+ }
+ #else // CUDART >= 12.0
template <typename T>
cusparseStatus_t cusparsegemmi( // NOLINT
cusparseHandle_t handle,
@@ -673,8 +741,9 @@ cusparseStatus_t cusparsegemmi( // NOLINT
cusparseDnMatDescr_t matA;
cusparseSpMatDescr_t matB;
cusparseDnMatDescr_t matC;
+ rmm::device_uvector<T> CT(m * n, stream);

- auto math_type = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
+ auto constexpr math_type = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
// Create sparse matrix B
CUSPARSE_CHECK(cusparseCreateCsc(&matB,
k,
@@ -687,30 +756,38 @@ cusparseStatus_t cusparsegemmi( // NOLINT
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO,
math_type));
- // Create dense matrices
+ /**
+ * Create dense matrices.
+ * Note: Since this is replacing `cusparse_gemmi`, it assumes dense inputs are
+ * column-ordered
+ */
CUSPARSE_CHECK(cusparseCreateDnMat(
- &matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_ROW));
+ &matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_COL));
CUSPARSE_CHECK(cusparseCreateDnMat(
- &matC, m, n, ldc, static_cast<void*>(const_cast<T*>(C)), math_type, CUSPARSE_ORDER_ROW));
+ &matC, n, m, n, static_cast<void*>(CT.data()), math_type, CUSPARSE_ORDER_COL));

- cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE;
- cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE;
- cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG2;
- size_t buffer_size = 0;
+ auto opA = CUSPARSE_OPERATION_TRANSPOSE;
+ auto opB = CUSPARSE_OPERATION_TRANSPOSE;
+ auto alg = CUSPARSE_SPMM_CSR_ALG1;
+ auto buffer_size = std::size_t{};

CUSPARSE_CHECK(cusparsespmm_bufferSize(
- handle, opA, opB, alpha, matB, matA, beta, matC, alg, &buffer_size, stream));
+ handle, opB, opA, alpha, matB, matA, beta, matC, alg, &buffer_size, stream));
buffer_size = buffer_size / sizeof(T);
rmm::device_uvector<T> external_buffer(buffer_size, stream);
- auto return_value = cusparsespmm(
- handle, opA, opB, alpha, matB, matA, beta, matC, alg, external_buffer.data(), stream);
+ auto ext_buf = static_cast<T*>(static_cast<void*>(external_buffer.data()));
+ auto return_value =
+ cusparsespmm(handle, opB, opA, alpha, matB, matA, beta, matC, alg, ext_buf, stream);

+ raft::handle_t rhandle;
+ raft::linalg::transpose(rhandle, CT.data(), C, n, m, stream);
// destroy matrix/vector descriptors
CUSPARSE_CHECK(cusparseDestroyDnMat(matA));
CUSPARSE_CHECK(cusparseDestroySpMat(matB));
CUSPARSE_CHECK(cusparseDestroyDnMat(matC));
return return_value;
}
+ #endif
/** @} */

/**
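A note on the CUDA 12 branch above: `cusparseSgemmi`/`cusparseDgemmi` compute C = alpha * A * B + beta * C for a dense, column-major A and a CSC-encoded sparse B; they are deprecated in the CUDA 11 toolkits (hence the `-Wdeprecated-declarations` pragmas) and gone from CUDA 12. The generic replacement, `cusparseSpMM`, only accepts the sparse operand on the left-hand side, so the wrapper evaluates the transposed product and transposes back at the end. Sketching the algebra (a reconstruction, not text from the PR):

$$
C = \alpha A B + \beta C
\quad\Longleftrightarrow\quad
C^{\mathsf{T}} = \alpha\, B^{\mathsf{T}} A^{\mathsf{T}} + \beta\, C^{\mathsf{T}}
$$

This identity is why the `cusparsespmm` calls pass `matB` before `matA` with both operations set to `CUSPARSE_OPERATION_TRANSPOSE`, why the scratch buffer `CT` is described as an n-by-m column-major matrix, and why `raft::linalg::transpose` runs last to write the n-by-m intermediate back into the m-by-n output `C`.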
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -213,6 +213,7 @@ if(BUILD_TESTS)
test/sparse/reduce.cu
test/sparse/row_op.cu
test/sparse/sort.cu
+ test/sparse/spgemmi.cu
test/sparse/symmetrize.cu
)

142 changes: 142 additions & 0 deletions cpp/test/sparse/spgemmi.cu
@@ -0,0 +1,142 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#include "../test_utils.h"

#include <raft/core/handle.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/sparse/detail/cusparse_wrappers.h>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>

#include <iostream>
#include <limits>

namespace raft {
namespace sparse {

struct SPGemmiInputs {
int n_rows, n_cols;
};

template <typename data_t>
class SPGemmiTest : public ::testing::TestWithParam<SPGemmiInputs> {
public:
SPGemmiTest()
: params(::testing::TestWithParam<SPGemmiInputs>::GetParam()), stream(handle.get_stream())
{
}

protected:
void SetUp() override {}

void Run()
{
// Host problem definition
float alpha = 1.0f;
float beta = 0.0f;
int A_num_rows = 5;
int A_num_cols = 3;
// int B_num_rows = A_num_cols;
int B_num_cols = 4;
int B_nnz = 9;
int lda = A_num_rows;
int ldc = A_num_rows;
int A_size = lda * A_num_cols;
int C_size = ldc * B_num_cols;
int hB_cscOffsets[] = {0, 3, 4, 7, 9};
int hB_rows[] = {0, 2, 3, 1, 0, 2, 3, 1, 3};
float hB_values[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
float hA[] = {1.0f,
2.0f,
3.0f,
4.0f,
5.0f,
6.0f,
7.0f,
8.0f,
9.0f,
10.0f,
11.0f,
12.0f,
13.0f,
14.0f,
15.0f};
std::vector<float> hC(C_size);
std::vector<float> hC_expected{23, 26, 29, 32, 35, 24, 28, 32, 36, 40,
71, 82, 93, 104, 115, 48, 56, 64, 72, 80};
//--------------------------------------------------------------------------
// Device memory management
rmm::device_uvector<int> dB_cscOffsets(B_num_cols + 1, stream);
rmm::device_uvector<int> dB_rows(B_nnz, stream);
rmm::device_uvector<float> dB_values(B_nnz, stream);
rmm::device_uvector<float> dA(A_size, stream);
rmm::device_uvector<float> dC(C_size, stream);
rmm::device_uvector<float> dCT(C_size, stream);

raft::update_device(dB_cscOffsets.data(), hB_cscOffsets, B_num_cols + 1, stream);
raft::update_device(dB_rows.data(), hB_rows, B_nnz, stream);
raft::update_device(dB_values.data(), hB_values, B_nnz, stream);
raft::update_device(dA.data(), hA, A_size, stream);
raft::update_device(dC.data(), hC.data(), C_size, stream);

//--------------------------------------------------------------------------
// execute gemmi
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsegemmi(handle.get_cusparse_handle(),
A_num_rows,
B_num_cols,
A_num_cols,
B_nnz,
&alpha,
dA.data(),
lda,
dB_values.data(),
dB_cscOffsets.data(),
dB_rows.data(),
&beta,
dC.data(),
ldc,
handle.get_stream()));

//--------------------------------------------------------------------------
// result check
raft::update_host(hC.data(), dC.data(), C_size, stream);
ASSERT_TRUE(hostVecMatch(hC_expected, hC, raft::Compare<float>()));
}

protected:
raft::handle_t handle;
cudaStream_t stream;

SPGemmiInputs params;
};

using SPGemmiTestF = SPGemmiTest<float>;
TEST_P(SPGemmiTestF, Result) { Run(); }

using SPGemmiTestD = SPGemmiTest<double>;
TEST_P(SPGemmiTestD, Result) { Run(); }

const std::vector<SPGemmiInputs> csc_inputs_f = {{5, 4}};
const std::vector<SPGemmiInputs> csc_inputs_d = {{5, 4}};

INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestF, ::testing::ValuesIn(csc_inputs_f));
INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestD, ::testing::ValuesIn(csc_inputs_d));

} // namespace sparse
} // namespace raft
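For reference, the hard-coded `hC_expected` values in the test can be reproduced with a plain host-side loop over the CSC structure. The sketch below is standalone and not part of the PR; note that `hB_rows` contains the row index 3 even though k = A_num_cols = 3, so the expected values only come out when those out-of-range entries are skipped:

// Host reference for the product the test checks: C = A * B (alpha = 1, beta = 0),
// with column-major dense A (m x k), CSC sparse B (k x n), column-major C (m x n).
// Standalone sketch, not part of the PR.
#include <cstdio>
#include <vector>

int main()
{
  const int m = 5, n = 4, k = 3;
  const float A[]    = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};  // 5x3
  const int colPtr[] = {0, 3, 4, 7, 9};
  const int rowInd[] = {0, 2, 3, 1, 0, 2, 3, 1, 3};
  const float vals[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  std::vector<float> C(m * n, 0.0f);

  for (int j = 0; j < n; ++j) {                        // columns of B and C
    for (int p = colPtr[j]; p < colPtr[j + 1]; ++p) {  // nonzeros of B[:, j]
      const int r = rowInd[p];
      if (r >= k) continue;  // skip entries whose row index has no matching column in A
      for (int i = 0; i < m; ++i) {                    // C[:, j] += vals[p] * A[:, r]
        C[i + j * m] += vals[p] * A[i + r * m];
      }
    }
  }

  for (float v : C) std::printf("%g ", v);  // prints 23 26 29 32 35 24 28 ... 72 80
  std::printf("\n");
  return 0;
}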