From e9d0944abad12d900247ce8da32320447b70559c Mon Sep 17 00:00:00 2001
From: Micka <9810050+lowener@users.noreply.github.com>
Date: Wed, 16 Nov 2022 16:21:46 +0100
Subject: [PATCH] Fix for gemmi deprecation (#1020)

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: https://github.com/rapidsai/raft/pull/1020
---
 build.sh                                      |   2 +
 .../raft/sparse/detail/cusparse_wrappers.h    |  99 ++++++++++--
 cpp/include/raft/sparse/linalg/transpose.cuh  |   6 +-
 cpp/test/CMakeLists.txt                       |   1 +
 cpp/test/sparse/csr_transpose.cu              |   1 -
 cpp/test/sparse/spgemmi.cu                    | 142 ++++++++++++++++++
 6 files changed, 234 insertions(+), 17 deletions(-)
 create mode 100644 cpp/test/sparse/spgemmi.cu

diff --git a/build.sh b/build.sh
index b48465922a..55e1c9a47a 100755
--- a/build.sh
+++ b/build.sh
@@ -65,6 +65,7 @@ CMAKE_LOG_LEVEL=""
 VERBOSE_FLAG=""
 BUILD_ALL_GPU_ARCH=0
 BUILD_TESTS=OFF
+BUILD_TYPE=Release
 BUILD_BENCH=OFF
 BUILD_STATIC_FAISS=OFF
 COMPILE_LIBRARIES=OFF
@@ -336,6 +337,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has
     cmake -S ${REPODIR}/cpp -B ${LIBRAFT_BUILD_DIR} \
           -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \
           -DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \
+          -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
          -DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \
          -DRAFT_ENABLE_NN_DEPENDENCIES=${ENABLE_NN_DEPENDENCIES} \
          -DRAFT_NVTX=${NVTX} \
diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
index c8e4229203..3bb2db7902 100644
--- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h
+++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -19,6 +19,7 @@
 #include <cusparse_v2.h>
 #include <raft/core/cusparse_macros.hpp>
 #include <raft/core/error.hpp>
+#include <raft/linalg/transpose.cuh>
 #include <rmm/device_uvector.hpp>
 
 namespace raft {
@@ -650,6 +651,73 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle,
  * @defgroup Gemmi cusparse gemmi operations
  * @{
  */
+#if CUDART_VERSION < 12000
+template <typename T>
+cusparseStatus_t cusparsegemmi(  // NOLINT
+  cusparseHandle_t handle,
+  int m,
+  int n,
+  int k,
+  int nnz,
+  const T* alpha,
+  const T* A,
+  int lda,
+  const T* cscValB,
+  const int* cscColPtrB,
+  const int* cscRowIndB,
+  const T* beta,
+  T* C,
+  int ldc,
+  cudaStream_t stream);
+template <>
+inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle,
+                                      int m,
+                                      int n,
+                                      int k,
+                                      int nnz,
+                                      const float* alpha,
+                                      const float* A,
+                                      int lda,
+                                      const float* cscValB,
+                                      const int* cscColPtrB,
+                                      const int* cscRowIndB,
+                                      const float* beta,
+                                      float* C,
+                                      int ldc,
+                                      cudaStream_t stream)
+{
+  CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+  return cusparseSgemmi(
+    handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc);
+#pragma GCC diagnostic pop
+}
+template <>
+inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle,
+                                      int m,
+                                      int n,
+                                      int k,
+                                      int nnz,
+                                      const double* alpha,
+                                      const double* A,
+                                      int lda,
+                                      const double* cscValB,
+                                      const int* cscColPtrB,
+                                      const int* cscRowIndB,
+                                      const double* beta,
+                                      double* C,
+                                      int ldc,
+                                      cudaStream_t stream)
+{
+  CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+  return cusparseDgemmi(
+    handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc);
+#pragma GCC diagnostic pop
+}
+#else  // CUDART >= 12.0
 template <typename T>
 cusparseStatus_t cusparsegemmi(  // NOLINT
   cusparseHandle_t handle,
@@ -673,8 +741,9 @@ cusparseStatus_t cusparsegemmi(  // NOLINT
   cusparseDnMatDescr_t matA;
   cusparseSpMatDescr_t matB;
   cusparseDnMatDescr_t matC;
+  rmm::device_uvector<T> CT(m * n, stream);
 
-  auto math_type = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
+  auto constexpr math_type = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
   // Create sparse matrix B
   CUSPARSE_CHECK(cusparseCreateCsc(&matB,
                                    k,
@@ -687,30 +756,38 @@ cusparseStatus_t cusparsegemmi(  // NOLINT
                                    CUSPARSE_INDEX_32I,
                                    CUSPARSE_INDEX_BASE_ZERO,
                                    math_type));
-  // Create dense matrices
+  /**
+   * Create dense matrices.
+   * Note: Since this is replacing `cusparse_gemmi`, it assumes dense inputs are
+   * column-ordered
+   */
   CUSPARSE_CHECK(cusparseCreateDnMat(
-    &matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_ROW));
+    &matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_COL));
   CUSPARSE_CHECK(cusparseCreateDnMat(
-    &matC, m, n, ldc, static_cast<void*>(const_cast<T*>(C)), math_type, CUSPARSE_ORDER_ROW));
+    &matC, n, m, n, static_cast<void*>(CT.data()), math_type, CUSPARSE_ORDER_COL));
 
-  cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE;
-  cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE;
-  cusparseSpMMAlg_t alg   = CUSPARSE_SPMM_CSR_ALG2;
-  size_t buffer_size      = 0;
+  auto opA         = CUSPARSE_OPERATION_TRANSPOSE;
+  auto opB         = CUSPARSE_OPERATION_TRANSPOSE;
+  auto alg         = CUSPARSE_SPMM_CSR_ALG1;
+  auto buffer_size = std::size_t{};
 
   CUSPARSE_CHECK(cusparsespmm_bufferSize(
-    handle, opA, opB, alpha, matB, matA, beta, matC, alg, &buffer_size, stream));
+    handle, opB, opA, alpha, matB, matA, beta, matC, alg, &buffer_size, stream));
   buffer_size = buffer_size / sizeof(T);
   rmm::device_uvector<T> external_buffer(buffer_size, stream);
-  auto return_value = cusparsespmm(
-    handle, opA, opB, alpha, matB, matA, beta, matC, alg, external_buffer.data(), stream);
+  auto ext_buf = static_cast<T*>(static_cast<void*>(external_buffer.data()));
+  auto return_value =
+    cusparsespmm(handle, opB, opA, alpha, matB, matA, beta, matC, alg, ext_buf, stream);
+  raft::handle_t rhandle;
+  raft::linalg::transpose(rhandle, CT.data(), C, n, m, stream);
 
   // destroy matrix/vector descriptors
   CUSPARSE_CHECK(cusparseDestroyDnMat(matA));
   CUSPARSE_CHECK(cusparseDestroySpMat(matB));
   CUSPARSE_CHECK(cusparseDestroyDnMat(matC));
   return return_value;
 }
+#endif
 
 /** @} */
 
 /**
diff --git a/cpp/include/raft/sparse/linalg/transpose.cuh b/cpp/include/raft/sparse/linalg/transpose.cuh
index fa0031aab6..ae527fe34c 100644
--- a/cpp/include/raft/sparse/linalg/transpose.cuh
+++ b/cpp/include/raft/sparse/linalg/transpose.cuh
@@ -13,8 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef __TRANSPOSE_H
-#define __TRANSPOSE_H
 
 #pragma once
 
@@ -69,6 +67,4 @@ void csr_transpose(const raft::handle_t& handle,
 
 };  // end NAMESPACE linalg
 };  // end NAMESPACE sparse
-};  // end NAMESPACE raft
-
-#endif
\ No newline at end of file
+};  // end NAMESPACE raft
\ No newline at end of file
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 66067e4dfd..3192330639 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -213,6 +213,7 @@ if(BUILD_TESTS)
     test/sparse/reduce.cu
     test/sparse/row_op.cu
     test/sparse/sort.cu
+    test/sparse/spgemmi.cu
     test/sparse/symmetrize.cu
   )
 
diff --git a/cpp/test/sparse/csr_transpose.cu b/cpp/test/sparse/csr_transpose.cu
index bea8f903cd..108d38a8b4 100644
--- a/cpp/test/sparse/csr_transpose.cu
+++ b/cpp/test/sparse/csr_transpose.cu
@@ -29,7 +29,6 @@ namespace raft {
 namespace sparse {
 
 using namespace raft;
-using namespace raft::sparse;
 
 template <typename T>
 struct CSRTransposeInputs {
diff --git a/cpp/test/sparse/spgemmi.cu b/cpp/test/sparse/spgemmi.cu
new file mode 100644
index 0000000000..a132c94fde
--- /dev/null
+++ b/cpp/test/sparse/spgemmi.cu
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "../test_utils.h"
+
+#include <raft/core/handle.hpp>
+#include <raft/sparse/detail/cusparse_wrappers.h>
+#include <raft/util/cuda_utils.cuh>
+#include <raft/util/cudart_utils.hpp>
+#include <rmm/device_uvector.hpp>
+
+#include <iostream>
+#include <limits>
+
+namespace raft {
+namespace sparse {
+
+struct SPGemmiInputs {
+  int n_rows, n_cols;
+};
+
+template <typename math_t>
+class SPGemmiTest : public ::testing::TestWithParam<SPGemmiInputs> {
+ public:
+  SPGemmiTest()
+    : params(::testing::TestWithParam<SPGemmiInputs>::GetParam()), stream(handle.get_stream())
+  {
+  }
+
+ protected:
+  void SetUp() override {}
+
+  void Run()
+  {
+    // Host problem definition
+    float alpha    = 1.0f;
+    float beta     = 0.0f;
+    int A_num_rows = 5;
+    int A_num_cols = 3;
+    // int B_num_rows = A_num_cols;
+    int B_num_cols      = 4;
+    int B_nnz           = 9;
+    int lda             = A_num_rows;
+    int ldc             = A_num_rows;
+    int A_size          = lda * A_num_cols;
+    int C_size          = ldc * B_num_cols;
+    int hB_cscOffsets[] = {0, 3, 4, 7, 9};
+    int hB_rows[]       = {0, 2, 3, 1, 0, 2, 3, 1, 3};
+    float hB_values[]   = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
+    float hA[]          = {1.0f,
+                           2.0f,
+                           3.0f,
+                           4.0f,
+                           5.0f,
+                           6.0f,
+                           7.0f,
+                           8.0f,
+                           9.0f,
+                           10.0f,
+                           11.0f,
+                           12.0f,
+                           13.0f,
+                           14.0f,
+                           15.0f};
+    std::vector<float> hC(C_size);
+    std::vector<float> hC_expected{23, 26, 29, 32, 35, 24, 28, 32, 36, 40,
+                                   71, 82, 93, 104, 115, 48, 56, 64, 72, 80};
+    //--------------------------------------------------------------------------
+    // Device memory management
+    rmm::device_uvector<int> dB_cscOffsets(B_num_cols + 1, stream);
+    rmm::device_uvector<int> dB_rows(B_nnz, stream);
+    rmm::device_uvector<float> dB_values(B_nnz, stream);
+    rmm::device_uvector<float> dA(A_size, stream);
+    rmm::device_uvector<float> dC(C_size, stream);
+    rmm::device_uvector<float> dCT(C_size, stream);
+
+    raft::update_device(dB_cscOffsets.data(), hB_cscOffsets, B_num_cols + 1, stream);
+    raft::update_device(dB_rows.data(), hB_rows, B_nnz, stream);
+    raft::update_device(dB_values.data(), hB_values, B_nnz, stream);
+    raft::update_device(dA.data(), hA, A_size, stream);
+    raft::update_device(dC.data(), hC.data(), C_size, stream);
+
+    //--------------------------------------------------------------------------
+    // execute gemmi
+    RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsegemmi(handle.get_cusparse_handle(),
+                                                          A_num_rows,
+                                                          B_num_cols,
+                                                          A_num_cols,
+                                                          B_nnz,
+                                                          &alpha,
+                                                          dA.data(),
+                                                          lda,
+                                                          dB_values.data(),
+                                                          dB_cscOffsets.data(),
+                                                          dB_rows.data(),
+                                                          &beta,
+                                                          dC.data(),
+                                                          ldc,
+                                                          handle.get_stream()));
+
+    //--------------------------------------------------------------------------
+    // result check
+    raft::update_host(hC.data(), dC.data(), C_size, stream);
+    ASSERT_TRUE(hostVecMatch(hC_expected, hC, raft::Compare<float>()));
+  }
+
+ protected:
+  raft::handle_t handle;
+  cudaStream_t stream;
+
+  SPGemmiInputs params;
+};
+
+using SPGemmiTestF = SPGemmiTest<float>;
+TEST_P(SPGemmiTestF, Result) { Run(); }
+
+using SPGemmiTestD = SPGemmiTest<double>;
+TEST_P(SPGemmiTestD, Result) { Run(); }
+
+const std::vector<SPGemmiInputs> csc_inputs_f = {{5, 4}};
+const std::vector<SPGemmiInputs> csc_inputs_d = {{5, 4}};
+
+INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestF, ::testing::ValuesIn(csc_inputs_f));
+INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestD, ::testing::ValuesIn(csc_inputs_d));
+
+}  // namespace sparse
+}  // namespace raft
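
Note (illustration only, not part of the patch): the CUDA >= 12 path above relies on the identity C = A*B  <=>  C^T = B^T * A^T. cusparseSpMM is asked for the transposed product into the column-major scratch matrix CT, and raft::linalg::transpose then writes the column-major C that the removed cusparse{S,D}gemmi API used to produce. Below is a minimal host-only C++ sketch of that identity, with plain loops standing in for cusparseSpMM and raft::linalg::transpose, and with alpha = 1, beta = 0 assumed; dimensions and names are illustrative.

// Column-major host sketch of the identity used by the CUDA 12 path:
//   C = A * B   <=>   C^T = B^T * A^T        (alpha = 1, beta = 0 assumed)
// Plain loops stand in for cusparseSpMM (which fills CT = C^T in the patch)
// and for raft::linalg::transpose (which recovers C from CT).
#include <cassert>
#include <vector>

int main()
{
  const int m = 2, k = 3, n = 2;  // A: m x k, B: k x n, C: m x n
  std::vector<double> A = {1, 2, 3, 4, 5, 6};     // column-major
  std::vector<double> B = {7, 8, 9, 10, 11, 12};  // column-major
  std::vector<double> C(m * n, 0.0);   // reference result, column-major
  std::vector<double> CT(n * m, 0.0);  // transposed product, column-major

  // Reference GEMM: C(i,j) = sum_l A(i,l) * B(l,j)
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      for (int l = 0; l < k; ++l)
        C[i + j * m] += A[i + l * m] * B[l + j * k];

  // SpMM stand-in: CT(j,i) = sum_l B^T(j,l) * A^T(l,i) = sum_l B(l,j) * A(i,l)
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int l = 0; l < k; ++l)
        CT[j + i * n] += B[l + j * k] * A[i + l * m];

  // Transpose stand-in: reading CT(j,i) as C(i,j) reproduces the reference.
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      assert(C[i + j * m] == CT[j + i * n]);
  return 0;
}

The j + i * n indexing stores CT column-major with leading dimension n, which matches the cusparseCreateDnMat(&matC, n, m, n, ...) call in the patch; transposing that n x m matrix recovers the m x n column-major C with leading dimension ldc that callers of the old gemmi API expect.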