Skip to content

Commit

Permalink
Fix for gemmi deprecation (rapidsai#1020)
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener authored Nov 16, 2022
1 parent 355f693 commit e9d0944
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 17 deletions.
2 changes: 2 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ CMAKE_LOG_LEVEL=""
VERBOSE_FLAG=""
BUILD_ALL_GPU_ARCH=0
BUILD_TESTS=OFF
BUILD_TYPE=Release
BUILD_BENCH=OFF
BUILD_STATIC_FAISS=OFF
COMPILE_LIBRARIES=OFF
Expand Down Expand Up @@ -336,6 +337,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has
cmake -S ${REPODIR}/cpp -B ${LIBRAFT_BUILD_DIR} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \
-DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \
-DRAFT_ENABLE_NN_DEPENDENCIES=${ENABLE_NN_DEPENDENCIES} \
-DRAFT_NVTX=${NVTX} \
Expand Down
99 changes: 88 additions & 11 deletions cpp/include/raft/sparse/detail/cusparse_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cusparse.h>
#include <raft/core/cusparse_macros.hpp>
#include <raft/core/error.hpp>
#include <raft/linalg/transpose.cuh>
#include <rmm/device_uvector.hpp>

namespace raft {
Expand Down Expand Up @@ -650,6 +651,73 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle,
* @defgroup Gemmi cusparse gemmi operations
* @{
*/
#if CUDART_VERSION < 12000
template <typename T>
cusparseStatus_t cusparsegemmi( // NOLINT
cusparseHandle_t handle,
int m,
int n,
int k,
int nnz,
const T* alpha,
const T* A,
int lda,
const T* cscValB,
const int* cscColPtrB,
const int* cscRowIndB,
const T* beta,
T* C,
int ldc,
cudaStream_t stream);
template <>
inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle,
int m,
int n,
int k,
int nnz,
const float* alpha,
const float* A,
int lda,
const float* cscValB,
const int* cscColPtrB,
const int* cscRowIndB,
const float* beta,
float* C,
int ldc,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
return cusparseSgemmi(
handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc);
#pragma GCC diagnostic pop
}
template <>
inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle,
int m,
int n,
int k,
int nnz,
const double* alpha,
const double* A,
int lda,
const double* cscValB,
const int* cscColPtrB,
const int* cscRowIndB,
const double* beta,
double* C,
int ldc,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
return cusparseDgemmi(
handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc);
#pragma GCC diagnostic pop
}
#else // CUDART >= 12.0
template <typename T>
cusparseStatus_t cusparsegemmi( // NOLINT
cusparseHandle_t handle,
Expand All @@ -673,8 +741,9 @@ cusparseStatus_t cusparsegemmi( // NOLINT
cusparseDnMatDescr_t matA;
cusparseSpMatDescr_t matB;
cusparseDnMatDescr_t matC;
rmm::device_uvector<T> CT(m * n, stream);

auto math_type = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
auto constexpr math_type = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
// Create sparse matrix B
CUSPARSE_CHECK(cusparseCreateCsc(&matB,
k,
Expand All @@ -687,30 +756,38 @@ cusparseStatus_t cusparsegemmi( // NOLINT
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO,
math_type));
// Create dense matrices
/**
* Create dense matrices.
* Note: Since this is replacing `cusparse_gemmi`, it assumes dense inputs are
* column-ordered
*/
CUSPARSE_CHECK(cusparseCreateDnMat(
&matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_ROW));
&matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_COL));
CUSPARSE_CHECK(cusparseCreateDnMat(
&matC, m, n, ldc, static_cast<void*>(const_cast<T*>(C)), math_type, CUSPARSE_ORDER_ROW));
&matC, n, m, n, static_cast<void*>(CT.data()), math_type, CUSPARSE_ORDER_COL));

cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE;
cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE;
cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG2;
size_t buffer_size = 0;
auto opA = CUSPARSE_OPERATION_TRANSPOSE;
auto opB = CUSPARSE_OPERATION_TRANSPOSE;
auto alg = CUSPARSE_SPMM_CSR_ALG1;
auto buffer_size = std::size_t{};

CUSPARSE_CHECK(cusparsespmm_bufferSize(
handle, opA, opB, alpha, matB, matA, beta, matC, alg, &buffer_size, stream));
handle, opB, opA, alpha, matB, matA, beta, matC, alg, &buffer_size, stream));
buffer_size = buffer_size / sizeof(T);
rmm::device_uvector<T> external_buffer(buffer_size, stream);
auto return_value = cusparsespmm(
handle, opA, opB, alpha, matB, matA, beta, matC, alg, external_buffer.data(), stream);
auto ext_buf = static_cast<T*>(static_cast<void*>(external_buffer.data()));
auto return_value =
cusparsespmm(handle, opB, opA, alpha, matB, matA, beta, matC, alg, ext_buf, stream);

raft::handle_t rhandle;
raft::linalg::transpose(rhandle, CT.data(), C, n, m, stream);
// destroy matrix/vector descriptors
CUSPARSE_CHECK(cusparseDestroyDnMat(matA));
CUSPARSE_CHECK(cusparseDestroySpMat(matB));
CUSPARSE_CHECK(cusparseDestroyDnMat(matC));
return return_value;
}
#endif
/** @} */

/**
Expand Down
6 changes: 1 addition & 5 deletions cpp/include/raft/sparse/linalg/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef __TRANSPOSE_H
#define __TRANSPOSE_H

#pragma once

Expand Down Expand Up @@ -69,6 +67,4 @@ void csr_transpose(const raft::handle_t& handle,

}; // end NAMESPACE linalg
}; // end NAMESPACE sparse
}; // end NAMESPACE raft

#endif
}; // end NAMESPACE raft
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ if(BUILD_TESTS)
test/sparse/reduce.cu
test/sparse/row_op.cu
test/sparse/sort.cu
test/sparse/spgemmi.cu
test/sparse/symmetrize.cu
)

Expand Down
1 change: 0 additions & 1 deletion cpp/test/sparse/csr_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ namespace raft {
namespace sparse {

using namespace raft;
using namespace raft::sparse;

template <typename value_idx, typename value_t>
struct CSRTransposeInputs {
Expand Down
142 changes: 142 additions & 0 deletions cpp/test/sparse/spgemmi.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#include "../test_utils.h"

#include <raft/core/handle.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/sparse/detail/cusparse_wrappers.h>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>

#include <iostream>
#include <limits>

namespace raft {
namespace sparse {

struct SPGemmiInputs {
int n_rows, n_cols;
};

template <typename data_t>
class SPGemmiTest : public ::testing::TestWithParam<SPGemmiInputs> {
public:
SPGemmiTest()
: params(::testing::TestWithParam<SPGemmiInputs>::GetParam()), stream(handle.get_stream())
{
}

protected:
void SetUp() override {}

void Run()
{
// Host problem definition
float alpha = 1.0f;
float beta = 0.0f;
int A_num_rows = 5;
int A_num_cols = 3;
// int B_num_rows = A_num_cols;
int B_num_cols = 4;
int B_nnz = 9;
int lda = A_num_rows;
int ldc = A_num_rows;
int A_size = lda * A_num_cols;
int C_size = ldc * B_num_cols;
int hB_cscOffsets[] = {0, 3, 4, 7, 9};
int hB_rows[] = {0, 2, 3, 1, 0, 2, 3, 1, 3};
float hB_values[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
float hA[] = {1.0f,
2.0f,
3.0f,
4.0f,
5.0f,
6.0f,
7.0f,
8.0f,
9.0f,
10.0f,
11.0f,
12.0f,
13.0f,
14.0f,
15.0f};
std::vector<float> hC(C_size);
std::vector<float> hC_expected{23, 26, 29, 32, 35, 24, 28, 32, 36, 40,
71, 82, 93, 104, 115, 48, 56, 64, 72, 80};
//--------------------------------------------------------------------------
// Device memory management
rmm::device_uvector<int> dB_cscOffsets(B_num_cols + 1, stream);
rmm::device_uvector<int> dB_rows(B_nnz, stream);
rmm::device_uvector<float> dB_values(B_nnz, stream);
rmm::device_uvector<float> dA(A_size, stream);
rmm::device_uvector<float> dC(C_size, stream);
rmm::device_uvector<float> dCT(C_size, stream);

raft::update_device(dB_cscOffsets.data(), hB_cscOffsets, B_num_cols + 1, stream);
raft::update_device(dB_rows.data(), hB_rows, B_nnz, stream);
raft::update_device(dB_values.data(), hB_values, B_nnz, stream);
raft::update_device(dA.data(), hA, A_size, stream);
raft::update_device(dC.data(), hC.data(), C_size, stream);

//--------------------------------------------------------------------------
// execute gemmi
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsegemmi(handle.get_cusparse_handle(),
A_num_rows,
B_num_cols,
A_num_cols,
B_nnz,
&alpha,
dA.data(),
lda,
dB_values.data(),
dB_cscOffsets.data(),
dB_rows.data(),
&beta,
dC.data(),
ldc,
handle.get_stream()));

//--------------------------------------------------------------------------
// result check
raft::update_host(hC.data(), dC.data(), C_size, stream);
ASSERT_TRUE(hostVecMatch(hC_expected, hC, raft::Compare<float>()));
}

protected:
raft::handle_t handle;
cudaStream_t stream;

SPGemmiInputs params;
};

using SPGemmiTestF = SPGemmiTest<float>;
TEST_P(SPGemmiTestF, Result) { Run(); }

using SPGemmiTestD = SPGemmiTest<double>;
TEST_P(SPGemmiTestD, Result) { Run(); }

const std::vector<SPGemmiInputs> csc_inputs_f = {{5, 4}};
const std::vector<SPGemmiInputs> csc_inputs_d = {{5, 4}};

INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestF, ::testing::ValuesIn(csc_inputs_f));
INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestD, ::testing::ValuesIn(csc_inputs_d));

} // namespace sparse
} // namespace raft

0 comments on commit e9d0944

Please sign in to comment.