From 3ed9143cba7f6178244eb91f3db15bd1bd5d267f Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Thu, 10 Nov 2022 20:05:09 +0100
Subject: [PATCH 1/8] Started gemmi fix

---
 cpp/include/raft/sparse/detail/cusparse_wrappers.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
index c8e4229203..676e438aae 100644
--- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h
+++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -689,13 +689,13 @@ cusparseStatus_t cusparsegemmi(  // NOLINT
     math_type));
   // Create dense matrices
   CUSPARSE_CHECK(cusparseCreateDnMat(
-    &matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_ROW));
+    &matA, m, k, lda, static_cast<void*>(const_cast<T*>(A)), math_type, CUSPARSE_ORDER_COL));
   CUSPARSE_CHECK(cusparseCreateDnMat(
-    &matC, m, n, ldc, static_cast<void*>(const_cast<T*>(C)), math_type, CUSPARSE_ORDER_ROW));
+    &matC, m, n, ldc, static_cast<void*>(C), math_type, CUSPARSE_ORDER_COL));
 
   cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE;
   cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE;
-  cusparseSpMMAlg_t alg   = CUSPARSE_SPMM_CSR_ALG2;
+  cusparseSpMMAlg_t alg   = CUSPARSE_SPMM_CSR_ALG1;
   size_t buffer_size      = 0;
 
   CUSPARSE_CHECK(cusparsespmm_bufferSize(

From c0b4b4dd87641221714a51f183350246a323de7d Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Fri, 11 Nov 2022 17:34:20 +0100
Subject: [PATCH 2/8] Add conversion to CSR

---
 .../raft/sparse/detail/cusparse_wrappers.h | 164 +++++++++++-------
 1 file changed, 97 insertions(+), 67 deletions(-)

diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
index 676e438aae..9c316996d6 100644
--- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h
+++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -646,73 +646,6 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle,
 /** @} */
 #endif
 
-/**
- * @defgroup Gemmi cusparse gemmi operations
- * @{
- */
-template <typename T>
-cusparseStatus_t cusparsegemmi(  // NOLINT
-  cusparseHandle_t handle,
-  int m,
-  int n,
-  int k,
-  int nnz,
-  const T* alpha,
-  const T* A,
-  int lda,
-  const T* cscValB,
-  const int* cscColPtrB,
-  const int* cscRowIndB,
-  const T* beta,
-  T* C,
-  int ldc,
-  cudaStream_t stream)
-{
-  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>, "Unsupported data type");
-
-  cusparseDnMatDescr_t matA;
-  cusparseSpMatDescr_t matB;
-  cusparseDnMatDescr_t matC;
-
-  auto math_type = std::is_same_v<T, float> ?
CUDA_R_32F : CUDA_R_64F; - // Create sparse matrix B - CUSPARSE_CHECK(cusparseCreateCsc(&matB, - k, - n, - nnz, - static_cast(const_cast(cscColPtrB)), - static_cast(const_cast(cscRowIndB)), - static_cast(const_cast(cscValB)), - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_BASE_ZERO, - math_type)); - // Create dense matrices - CUSPARSE_CHECK(cusparseCreateDnMat( - &matA, m, k, lda, static_cast(const_cast(A)), math_type, CUSPARSE_ORDER_COL)); - CUSPARSE_CHECK(cusparseCreateDnMat( - &matC, m, n, ldc, static_cast(C), math_type, CUSPARSE_ORDER_COL)); - - cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE; - cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG1; - size_t buffer_size = 0; - - CUSPARSE_CHECK(cusparsespmm_bufferSize( - handle, opA, opB, alpha, matB, matA, beta, matC, alg, &buffer_size, stream)); - buffer_size = buffer_size / sizeof(T); - rmm::device_uvector external_buffer(buffer_size, stream); - auto return_value = cusparsespmm( - handle, opA, opB, alpha, matB, matA, beta, matC, alg, external_buffer.data(), stream); - - // destroy matrix/vector descriptors - CUSPARSE_CHECK(cusparseDestroyDnMat(matA)); - CUSPARSE_CHECK(cusparseDestroySpMat(matB)); - CUSPARSE_CHECK(cusparseDestroyDnMat(matC)); - return return_value; -} -/** @} */ - /** * @defgroup csr2coo cusparse CSR to COO converter methods * @{ @@ -1175,6 +1108,103 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, /** @} */ +/** + * @defgroup Gemmi cusparse gemmi operations + * @{ + */ +template +cusparseStatus_t cusparsegemmi( // NOLINT + cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const T* alpha, + const T* A, + int lda, + const T* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const T* beta, + T* C, + int ldc, + cudaStream_t stream) +{ + static_assert(std::is_same_v || std::is_same_v, "Unsupported data type"); + + cusparseDnMatDescr_t matA; + cusparseSpMatDescr_t matB; + cusparseDnMatDescr_t matC; + + auto math_type = std::is_same_v ? 
CUDA_R_32F : CUDA_R_64F; + auto buffer_size = size_t{}; + // Convert CSC to CSR + rmm::device_uvector csrValB(nnz, stream); + rmm::device_uvector csrRowPtrB(k + 1, stream); + rmm::device_uvector csrColIndB(nnz, stream); + CUSPARSE_CHECK(cusparsecsr2csc_bufferSize(handle, + n, + k, + nnz, + cscValB, + cscColPtrB, + cscRowIndB, + csrValB.data(), + csrRowPtrB.data(), + csrColIndB.data(), + CUSPARSE_ACTION_NUMERIC, + CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG1, + &buffer_size, + stream)); + rmm::device_uvector external_buffer(buffer_size, stream); + CUSPARSE_CHECK(cusparsecsr2csc(handle, + n, + k, + nnz, + cscValB, + cscColPtrB, + cscRowIndB, + csrValB.data(), + csrRowPtrB.data(), + csrColIndB.data(), + CUSPARSE_ACTION_NUMERIC, + CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG1, + static_cast(external_buffer.data()), + stream)); + + // Create sparse matrix B + CUSPARSE_CHECK( + cusparsecreatecsr(&matB, k, n, nnz, csrRowPtrB.data(), csrColIndB.data(), csrValB.data())); + + // Create dense matrices + CUSPARSE_CHECK(cusparseCreateDnMat( + &matA, m, k, lda, static_cast(const_cast(A)), math_type, CUSPARSE_ORDER_COL)); + CUSPARSE_CHECK( + cusparseCreateDnMat(&matC, n, m, n, static_cast(C), math_type, CUSPARSE_ORDER_COL)); + + cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE; + cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG1; + + CUSPARSE_CHECK(cusparsespmm_bufferSize( + handle, opB, opA, alpha, matB, matA, beta, matC, alg, &buffer_size, stream)); + // buffer_size = buffer_size / sizeof(T); + external_buffer.resize(buffer_size, stream); + auto ext_buf = static_cast(static_cast(external_buffer.data())); + auto return_value = + cusparsespmm(handle, opB, opA, alpha, matB, matA, beta, matC, alg, ext_buf, stream); + + // destroy matrix/vector descriptors + CUSPARSE_CHECK(cusparseDestroyDnMat(matA)); + CUSPARSE_CHECK(cusparseDestroySpMat(matB)); + CUSPARSE_CHECK(cusparseDestroyDnMat(matC)); + + return return_value; +} +/** @} */ + } // namespace detail } // namespace sparse } // namespace raft From 2ef6cb76e7ea039a58304a11e45c5cb97b276780 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 15 Nov 2022 14:48:35 +0100 Subject: [PATCH 3/8] Add transpose at the end --- build.sh | 2 + .../raft/sparse/detail/cusparse_wrappers.h | 176 ++++++++---------- 2 files changed, 81 insertions(+), 97 deletions(-) diff --git a/build.sh b/build.sh index b48465922a..55e1c9a47a 100755 --- a/build.sh +++ b/build.sh @@ -65,6 +65,7 @@ CMAKE_LOG_LEVEL="" VERBOSE_FLAG="" BUILD_ALL_GPU_ARCH=0 BUILD_TESTS=OFF +BUILD_TYPE=Release BUILD_BENCH=OFF BUILD_STATIC_FAISS=OFF COMPILE_LIBRARIES=OFF @@ -336,6 +337,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has cmake -S ${REPODIR}/cpp -B ${LIBRAFT_BUILD_DIR} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ -DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \ -DRAFT_ENABLE_NN_DEPENDENCIES=${ENABLE_NN_DEPENDENCIES} \ -DRAFT_NVTX=${NVTX} \ diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 9c316996d6..185be3a2cf 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -19,6 +19,7 @@ #include #include #include +#include #include namespace raft { @@ -646,6 +647,84 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle, /** @} */ #endif +/** + * @defgroup 
Gemmi cusparse gemmi operations + * @{ + */ +template +cusparseStatus_t cusparsegemmi( // NOLINT + cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const T* alpha, + const T* A, + int lda, + const T* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const T* beta, + T* C, + int ldc, + cudaStream_t stream) +{ + static_assert(std::is_same_v || std::is_same_v, "Unsupported data type"); + + cusparseDnMatDescr_t matA; + cusparseSpMatDescr_t matB; + cusparseDnMatDescr_t matC; + rmm::device_uvector CT(m * n, stream); + + auto constexpr math_type = std::is_same_v ? CUDA_R_32F : CUDA_R_64F; + // Create sparse matrix B + CUSPARSE_CHECK(cusparseCreateCsc(&matB, + k, + n, + nnz, + static_cast(const_cast(cscColPtrB)), + static_cast(const_cast(cscRowIndB)), + static_cast(const_cast(cscValB)), + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, + math_type)); + /** + * Create dense matrices. + * Note: Since this is replacing `cusparse_gemmi`, it assumes dense inputs are + * column-ordered + */ + CUSPARSE_CHECK(cusparseCreateDnMat( + &matA, m, k, lda, static_cast(const_cast(A)), math_type, CUSPARSE_ORDER_COL)); + CUSPARSE_CHECK(cusparseCreateDnMat( + &matC, n, m, n, static_cast(CT.data()), math_type, CUSPARSE_ORDER_COL)); + + printf("m=%d, n=%d, k=%d\n", m, n, k); + + auto opA = CUSPARSE_OPERATION_TRANSPOSE; + auto opB = CUSPARSE_OPERATION_TRANSPOSE; + auto alg = CUSPARSE_SPMM_CSR_ALG1; + auto buffer_size = std::size_t{}; + + CUSPARSE_CHECK(cusparsespmm_bufferSize( + handle, opB, opA, alpha, matB, matA, beta, matC, alg, &buffer_size, stream)); + buffer_size = buffer_size / sizeof(T); + rmm::device_uvector external_buffer(buffer_size, stream); + auto ext_buf = static_cast(static_cast(external_buffer.data())); + auto return_value = + cusparsespmm(handle, opB, opA, alpha, matB, matA, beta, matC, alg, ext_buf, stream); + + raft::handle_t rhandle; + raft::linalg::transpose(rhandle, CT.data(), C, n, m, stream); + // destroy matrix/vector descriptors + CUSPARSE_CHECK(cusparseDestroyDnMat(matA)); + CUSPARSE_CHECK(cusparseDestroySpMat(matB)); + CUSPARSE_CHECK(cusparseDestroyDnMat(matC)); + + return return_value; +} +/** @} */ + /** * @defgroup csr2coo cusparse CSR to COO converter methods * @{ @@ -1108,103 +1187,6 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, /** @} */ -/** - * @defgroup Gemmi cusparse gemmi operations - * @{ - */ -template -cusparseStatus_t cusparsegemmi( // NOLINT - cusparseHandle_t handle, - int m, - int n, - int k, - int nnz, - const T* alpha, - const T* A, - int lda, - const T* cscValB, - const int* cscColPtrB, - const int* cscRowIndB, - const T* beta, - T* C, - int ldc, - cudaStream_t stream) -{ - static_assert(std::is_same_v || std::is_same_v, "Unsupported data type"); - - cusparseDnMatDescr_t matA; - cusparseSpMatDescr_t matB; - cusparseDnMatDescr_t matC; - - auto math_type = std::is_same_v ? 
CUDA_R_32F : CUDA_R_64F; - auto buffer_size = size_t{}; - // Convert CSC to CSR - rmm::device_uvector csrValB(nnz, stream); - rmm::device_uvector csrRowPtrB(k + 1, stream); - rmm::device_uvector csrColIndB(nnz, stream); - CUSPARSE_CHECK(cusparsecsr2csc_bufferSize(handle, - n, - k, - nnz, - cscValB, - cscColPtrB, - cscRowIndB, - csrValB.data(), - csrRowPtrB.data(), - csrColIndB.data(), - CUSPARSE_ACTION_NUMERIC, - CUSPARSE_INDEX_BASE_ZERO, - CUSPARSE_CSR2CSC_ALG1, - &buffer_size, - stream)); - rmm::device_uvector external_buffer(buffer_size, stream); - CUSPARSE_CHECK(cusparsecsr2csc(handle, - n, - k, - nnz, - cscValB, - cscColPtrB, - cscRowIndB, - csrValB.data(), - csrRowPtrB.data(), - csrColIndB.data(), - CUSPARSE_ACTION_NUMERIC, - CUSPARSE_INDEX_BASE_ZERO, - CUSPARSE_CSR2CSC_ALG1, - static_cast(external_buffer.data()), - stream)); - - // Create sparse matrix B - CUSPARSE_CHECK( - cusparsecreatecsr(&matB, k, n, nnz, csrRowPtrB.data(), csrColIndB.data(), csrValB.data())); - - // Create dense matrices - CUSPARSE_CHECK(cusparseCreateDnMat( - &matA, m, k, lda, static_cast(const_cast(A)), math_type, CUSPARSE_ORDER_COL)); - CUSPARSE_CHECK( - cusparseCreateDnMat(&matC, n, m, n, static_cast(C), math_type, CUSPARSE_ORDER_COL)); - - cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE; - cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG1; - - CUSPARSE_CHECK(cusparsespmm_bufferSize( - handle, opB, opA, alpha, matB, matA, beta, matC, alg, &buffer_size, stream)); - // buffer_size = buffer_size / sizeof(T); - external_buffer.resize(buffer_size, stream); - auto ext_buf = static_cast(static_cast(external_buffer.data())); - auto return_value = - cusparsespmm(handle, opB, opA, alpha, matB, matA, beta, matC, alg, ext_buf, stream); - - // destroy matrix/vector descriptors - CUSPARSE_CHECK(cusparseDestroyDnMat(matA)); - CUSPARSE_CHECK(cusparseDestroySpMat(matB)); - CUSPARSE_CHECK(cusparseDestroyDnMat(matC)); - - return return_value; -} -/** @} */ - } // namespace detail } // namespace sparse } // namespace raft From 375e6fd72d9641699ffcd80291caf1318be6a7e3 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 15 Nov 2022 17:28:51 +0100 Subject: [PATCH 4/8] Add guards for gemmi and gtest --- .../raft/sparse/detail/cusparse_wrappers.h | 68 +++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/spgemmi.cu | 142 ++++++++++++++++++ 3 files changed, 211 insertions(+) create mode 100644 cpp/test/sparse/spgemmi.cu diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 185be3a2cf..254da14c7e 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -651,6 +651,73 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle, * @defgroup Gemmi cusparse gemmi operations * @{ */ +#if CUDART_VERSION < 12000 +template +cusparseStatus_t cusparsegemmi( // NOLINT + cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const T* alpha, + const T* A, + int lda, + const T* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const T* beta, + T* C, + int ldc, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const float* alpha, + const float* A, + int lda, + const float* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const float* beta, + float* C, + int ldc, + cudaStream_t stream) +{ + 
CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseSgemmi( + handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc); +#pragma GCC diagnostic pop +} +template <> +inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const double* alpha, + const double* A, + int lda, + const double* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const double* beta, + double* C, + int ldc, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseDgemmi( + handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc); +#pragma GCC diagnostic pop +} +#else // CUDART >= 12.0 template cusparseStatus_t cusparsegemmi( // NOLINT cusparseHandle_t handle, @@ -723,6 +790,7 @@ cusparseStatus_t cusparsegemmi( // NOLINT return return_value; } +#endif /** @} */ /** diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0f5ebabcb9..cb2ec066b3 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -216,6 +216,7 @@ if(BUILD_TESTS) test/sparse/reduce.cu test/sparse/row_op.cu test/sparse/sort.cu + test/sparse/spgemmi.cu test/sparse/symmetrize.cu ) diff --git a/cpp/test/sparse/spgemmi.cu b/cpp/test/sparse/spgemmi.cu new file mode 100644 index 0000000000..a132c94fde --- /dev/null +++ b/cpp/test/sparse/spgemmi.cu @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "../test_utils.h" + +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace sparse { + +struct SPGemmiInputs { + int n_rows, n_cols; +}; + +template +class SPGemmiTest : public ::testing::TestWithParam { + public: + SPGemmiTest() + : params(::testing::TestWithParam::GetParam()), stream(handle.get_stream()) + { + } + + protected: + void SetUp() override {} + + void Run() + { + // Host problem definition + float alpha = 1.0f; + float beta = 0.0f; + int A_num_rows = 5; + int A_num_cols = 3; + // int B_num_rows = A_num_cols; + int B_num_cols = 4; + int B_nnz = 9; + int lda = A_num_rows; + int ldc = A_num_rows; + int A_size = lda * A_num_cols; + int C_size = ldc * B_num_cols; + int hB_cscOffsets[] = {0, 3, 4, 7, 9}; + int hB_rows[] = {0, 2, 3, 1, 0, 2, 3, 1, 3}; + float hB_values[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; + float hA[] = {1.0f, + 2.0f, + 3.0f, + 4.0f, + 5.0f, + 6.0f, + 7.0f, + 8.0f, + 9.0f, + 10.0f, + 11.0f, + 12.0f, + 13.0f, + 14.0f, + 15.0f}; + std::vector hC(C_size); + std::vector hC_expected{23, 26, 29, 32, 35, 24, 28, 32, 36, 40, + 71, 82, 93, 104, 115, 48, 56, 64, 72, 80}; + //-------------------------------------------------------------------------- + // Device memory management + rmm::device_uvector dB_cscOffsets(B_num_cols + 1, stream); + rmm::device_uvector dB_rows(B_nnz, stream); + rmm::device_uvector dB_values(B_nnz, stream); + rmm::device_uvector dA(A_size, stream); + rmm::device_uvector dC(C_size, stream); + rmm::device_uvector dCT(C_size, stream); + + raft::update_device(dB_cscOffsets.data(), hB_cscOffsets, B_num_cols + 1, stream); + raft::update_device(dB_rows.data(), hB_rows, B_nnz, stream); + raft::update_device(dB_values.data(), hB_values, B_nnz, stream); + raft::update_device(dA.data(), hA, A_size, stream); + raft::update_device(dC.data(), hC.data(), C_size, stream); + + //-------------------------------------------------------------------------- + // execute gemmi + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsegemmi(handle.get_cusparse_handle(), + A_num_rows, + B_num_cols, + A_num_cols, + B_nnz, + &alpha, + dA.data(), + lda, + dB_values.data(), + dB_cscOffsets.data(), + dB_rows.data(), + &beta, + dC.data(), + ldc, + handle.get_stream())); + + //-------------------------------------------------------------------------- + // result check + raft::update_host(hC.data(), dC.data(), C_size, stream); + ASSERT_TRUE(hostVecMatch(hC_expected, hC, raft::Compare())); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + SPGemmiInputs params; +}; + +using SPGemmiTestF = SPGemmiTest; +TEST_P(SPGemmiTestF, Result) { Run(); } + +using SPGemmiTestD = SPGemmiTest; +TEST_P(SPGemmiTestD, Result) { Run(); } + +const std::vector csc_inputs_f = {{5, 4}}; +const std::vector csc_inputs_d = {{5, 4}}; + +INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestF, ::testing::ValuesIn(csc_inputs_f)); +INSTANTIATE_TEST_CASE_P(SparseGemmi, SPGemmiTestD, ::testing::ValuesIn(csc_inputs_d)); + +} // namespace sparse +} // namespace raft From 2565a801058ebf26227cd7feb23e13a766245ed5 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 15 Nov 2022 17:31:36 +0100 Subject: [PATCH 5/8] remove print --- cpp/include/raft/sparse/detail/cusparse_wrappers.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 254da14c7e..3bb2db7902 100644 --- 
a/cpp/include/raft/sparse/detail/cusparse_wrappers.h
+++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -766,8 +766,6 @@ cusparseStatus_t cusparsegemmi(  // NOLINT
   CUSPARSE_CHECK(cusparseCreateDnMat(
     &matC, n, m, n, static_cast<void*>(CT.data()), math_type, CUSPARSE_ORDER_COL));
 
-  printf("m=%d, n=%d, k=%d\n", m, n, k);
-
   auto opA = CUSPARSE_OPERATION_TRANSPOSE;
   auto opB = CUSPARSE_OPERATION_TRANSPOSE;
   auto alg = CUSPARSE_SPMM_CSR_ALG1;
@@ -787,7 +785,6 @@ cusparseStatus_t cusparsegemmi(  // NOLINT
   CUSPARSE_CHECK(cusparseDestroyDnMat(matA));
   CUSPARSE_CHECK(cusparseDestroySpMat(matB));
   CUSPARSE_CHECK(cusparseDestroyDnMat(matC));
-
   return return_value;
 }
 #endif

From f107794425dca9e87c8f985cbbcfe4c76e443548 Mon Sep 17 00:00:00 2001
From: "Corey J. Nolet"
Date: Tue, 15 Nov 2022 21:16:27 -0500
Subject: [PATCH 6/8] Not sure why this was failing...

---
 cpp/test/sparse/csr_transpose.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/test/sparse/csr_transpose.cu b/cpp/test/sparse/csr_transpose.cu
index bea8f903cd..93a5a4970c 100644
--- a/cpp/test/sparse/csr_transpose.cu
+++ b/cpp/test/sparse/csr_transpose.cu
@@ -29,7 +29,7 @@
 namespace raft {
 namespace sparse {
 using namespace raft;
-using namespace raft::sparse;
+
 
 template <typename value_idx, typename value_t>
 struct CSRTransposeInputs {

From 1602067e46a9aae75921b773bf98f25aeef25216 Mon Sep 17 00:00:00 2001
From: "Corey J. Nolet"
Date: Tue, 15 Nov 2022 21:17:34 -0500
Subject: [PATCH 7/8] Fixing style

---
 cpp/test/sparse/csr_transpose.cu | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cpp/test/sparse/csr_transpose.cu b/cpp/test/sparse/csr_transpose.cu
index 93a5a4970c..108d38a8b4 100644
--- a/cpp/test/sparse/csr_transpose.cu
+++ b/cpp/test/sparse/csr_transpose.cu
@@ -30,7 +30,6 @@
 namespace sparse {
 using namespace raft;
 
-
 template <typename value_idx, typename value_t>
 struct CSRTransposeInputs {
   value_idx nrows;

From bb9e16b742089d068b2536984416d5f88e614797 Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Wed, 16 Nov 2022 13:19:13 +0100
Subject: [PATCH 8/8] Remove csr_transpose guards

---
 cpp/include/raft/sparse/linalg/transpose.cuh | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/cpp/include/raft/sparse/linalg/transpose.cuh b/cpp/include/raft/sparse/linalg/transpose.cuh
index fa0031aab6..ae527fe34c 100644
--- a/cpp/include/raft/sparse/linalg/transpose.cuh
+++ b/cpp/include/raft/sparse/linalg/transpose.cuh
@@ -13,8 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef __TRANSPOSE_H
-#define __TRANSPOSE_H
 
 #pragma once
 
@@ -69,6 +67,4 @@ void csr_transpose(const raft::handle_t& handle,
 
 };  // end NAMESPACE linalg
 };  // end NAMESPACE sparse
-};  // end NAMESPACE raft
-
-#endif
\ No newline at end of file
+};  // end NAMESPACE raft
\ No newline at end of file
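
Note on the CUDA 12 path: cusparseSgemmi/cusparseDgemmi are deprecated and removed from
cuSPARSE in CUDA 12, so the guarded wrapper in this series keeps forwarding to them for
CUDART_VERSION < 12000 and otherwise emulates gemmi (C = alpha * A * B + beta * C, with
dense column-major A and C and a CSC-encoded sparse B) through cusparseSpMM: it computes
C^T = alpha * B^T * A^T + beta * C^T into a temporary buffer and transposes the result back
into C with raft::linalg::transpose. Below is a minimal caller-side sketch, condensed from
cpp/test/sparse/spgemmi.cu above; the sizes and buffer contents are illustrative only, and
all names used here (RAFT_CUSPARSE_TRY, raft::handle_t, rmm::device_uvector, the wrapper
signature) come from the patches themselves.

  // Sketch only: C (m x n) = alpha * A (m x k) * B (k x n, CSC) + beta * C.
  raft::handle_t handle;
  auto stream = handle.get_stream();
  int m = 5, n = 4, k = 3, nnz = 9;
  float alpha = 1.0f, beta = 0.0f;
  rmm::device_uvector<float> dA(m * k, stream);           // dense A, column-major, lda = m
  rmm::device_uvector<float> dC(m * n, stream);           // dense C, column-major, ldc = m
  rmm::device_uvector<float> dB_values(nnz, stream);      // CSC values of B
  rmm::device_uvector<int> dB_cscOffsets(n + 1, stream);  // CSC column offsets of B
  rmm::device_uvector<int> dB_rows(nnz, stream);          // CSC row indices of B
  // ... fill the device buffers, e.g. with raft::update_device() as in the test ...
  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsegemmi(handle.get_cusparse_handle(),
                                                        m, n, k, nnz, &alpha,
                                                        dA.data(), m,
                                                        dB_values.data(),
                                                        dB_cscOffsets.data(),
                                                        dB_rows.data(),
                                                        &beta,
                                                        dC.data(), m,
                                                        handle.get_stream()));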