From 4563c3b343132313920fc33034d0c30c5f6d2bd3 Mon Sep 17 00:00:00 2001
From: rhdong
Date: Thu, 11 Jan 2024 15:06:56 -0800
Subject: [PATCH] add benchmark for SDDMM

---
NOTE(review): angle-bracket contents appear to have been stripped from this copy of the
patch (the #include targets, every template parameter/argument list, the author e-mail).
They are left exactly as found rather than guessed; restore them from the upstream commit
before applying. The line counts below were updated for the comments and fixes added in
review; the blob id on the sddmm.cu "index" line still refers to the original content.

 cpp/bench/prims/CMakeLists.txt  |   3 +-
 cpp/bench/prims/linalg/sddmm.cu | 350 ++++++++++++++++++++++++++++++++
 2 files changed, 352 insertions(+), 1 deletion(-)
 create mode 100644 cpp/bench/prims/linalg/sddmm.cu

diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt
index fe58453d0d..3a2431cd34 100644
--- a/cpp/bench/prims/CMakeLists.txt
+++ b/cpp/bench/prims/CMakeLists.txt
@@ -1,5 +1,5 @@
 # =============================================================================
-# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 # in compliance with the License. You may obtain a copy of the License at
@@ -117,6 +117,7 @@ if(BUILD_PRIMS_BENCH)
     bench/prims/linalg/reduce_cols_by_key.cu
     bench/prims/linalg/reduce_rows_by_key.cu
     bench/prims/linalg/reduce.cu
+    bench/prims/linalg/sddmm.cu
     bench/prims/main.cpp
   )
 
diff --git a/cpp/bench/prims/linalg/sddmm.cu b/cpp/bench/prims/linalg/sddmm.cu
new file mode 100644
index 0000000000..139a2b838d
--- /dev/null
+++ b/cpp/bench/prims/linalg/sddmm.cu
@@ -0,0 +1,350 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+
+namespace raft::bench::linalg {
+
+/**
+ * @brief Shape and configuration of one SDDMM benchmark case.
+ *
+ * SDDMMBench sizes the A buffer as m * k elements, the B buffer as k * n elements and the
+ * sparse output C as m x n. `sparsity` is the fraction of C's m * n entries that are stored
+ * (see SDDMMBench::create_sparse_matrix), `transpose_a` / `transpose_b` choose between
+ * Operation::TRANSPOSE and Operation::NON_TRANSPOSE for A / B, and `alpha` / `beta` are passed
+ * to sddmm as host scalars.
+ */
+template
+struct SDDMMBenchParams {
+  size_t m;
+  size_t k;
+  size_t n;
+  float sparsity;
+  bool transpose_a;
+  bool transpose_b;
+  ValueType alpha = 1.0;
+  ValueType beta  = 0.0;
+};
+
+/** @brief Operation run in the timed loop: sddmm, or an inner-product pairwise_distance. */
+enum Alg { SDDMM, Inner };
+
+/** @brief Writes the parameters in the form used for the benchmark label. */
+template
+inline auto operator<<(std::ostream& os, const SDDMMBenchParams& params) -> std::ostream&
+{
+  os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n
+     << "\tsparsity=" << params.sparsity << "\ttrans_a=" << (params.transpose_a ? "T" : "F")
+     << " trans_b=" << (params.transpose_b ? "T" : "F");
+  return os;
+}
+
+/**
+ * @brief Benchmark fixture that runs raft::sparse::linalg::sddmm or, when SDDMMorInner is
+ * Alg::Inner, raft::distance::pairwise_distance with DistanceType::InnerProduct on the same
+ * A and B buffers.
+ */
+template
+struct SDDMMBench : public fixture {
+  /**
+   * @brief Allocates the device buffers, fills A and B with uniform values between -1 and 1
+   * and uploads a randomly generated CSR sparsity pattern for C.
+   *
+   * @param p benchmark case to set up; stored by value
+   */
+  SDDMMBench(const SDDMMBenchParams& p)
+    : fixture(true),
+      // Members are listed in declaration order, which is the order they are initialized in.
+      handle(stream),
+      params(p),
+      a_data_d(0, stream),
+      b_data_d(0, stream),
+      c_dense_data_d(0, stream),
+      c_indptr_d(0, stream),
+      c_indices_d(0, stream),
+      c_data_d(0, stream)
+  {
+    a_data_d.resize(params.m * params.k, stream);
+    b_data_d.resize(params.k * params.n, stream);
+
+    raft::random::RngState rng(2024ULL);
+    raft::random::uniform(
+      handle, rng, a_data_d.data(), params.m * params.k, ValueType(-1.0), ValueType(1.0));
+    raft::random::uniform(
+      handle, rng, b_data_d.data(), params.k * params.n, ValueType(-1.0), ValueType(1.0));
+
+    // Host-side dense m x n mask of the entries of C that will be stored.
+    std::vector c_dense_data_h(params.m * params.n);
+
+    c_true_nnz = create_sparse_matrix(c_dense_data_h);
+    std::vector values(c_true_nnz);
+    std::vector indices(c_true_nnz);
+    std::vector indptr(params.m + 1);
+
+    c_data_d.resize(c_true_nnz, stream);
+    c_indptr_d.resize(params.m + 1, stream);
+    c_indices_d.resize(c_true_nnz, stream);
+
+    // The dense m x n output is only written by the pairwise_distance path.
+    if (SDDMMorInner == Alg::Inner) { c_dense_data_d.resize(params.m * params.n, stream); }
+
+    convert_to_csr(c_dense_data_h, params.m, params.n, values, indices, indptr);
+    RAFT_EXPECTS(c_true_nnz == c_indices_d.size(),
+                 "Something wrong. The c_true_nnz != c_indices_d.size()!");
+
+    update_device(c_data_d.data(), values.data(), c_true_nnz, stream);
+    update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream);
+    update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream);
+  }
+
+  /**
+   * @brief Fills pre-sized CSR arrays from a dense row-major mask.
+   *
+   * Each set entry (i, j) of `matrix` is emitted as a stored element with value 1.0 and column
+   * index j; `indptr[i + 1]` receives the number of elements emitted up to and including row i.
+   *
+   * @param matrix  dense mask holding rows * cols entries
+   * @param rows    number of rows of the mask
+   * @param cols    number of columns of the mask
+   * @param values  output values, one slot per set entry
+   * @param indices output column indices, one slot per set entry
+   * @param indptr  output row offsets, rows + 1 slots
+   */
+  void convert_to_csr(std::vector& matrix,
+                      IndexType rows,
+                      IndexType cols,
+                      std::vector& values,
+                      std::vector& indices,
+                      std::vector& indptr)
+  {
+    IndexType offset_indptr = 0;
+    IndexType offset_values = 0;
+    indptr[offset_indptr++] = 0;
+
+    for (IndexType i = 0; i < rows; ++i) {
+      for (IndexType j = 0; j < cols; ++j) {
+        if (matrix[i * cols + j]) {
+          values[offset_values]  = static_cast(1.0);
+          indices[offset_values] = static_cast(j);
+          offset_values++;
+        }
+      }
+      indptr[offset_indptr++] = static_cast(offset_values);
+    }
+  }
+
+  /**
+   * @brief Marks a random subset of the m * n mask entries as stored.
+   *
+   * The mask is cleared, then params.sparsity * m * n (truncated) distinct positions are set
+   * by drawing uniform indices and retrying on positions that are already set. The generator is
+   * seeded from std::random_device, so the pattern differs between runs.
+   *
+   * @param matrix dense mask with at least m * n entries
+   * @return number of entries that were set
+   */
+  size_t create_sparse_matrix(std::vector& matrix)
+  {
+    const size_t total_elements = params.m * params.n;
+
+    for (size_t i = 0; i < total_elements; ++i) {
+      matrix[i] = false;
+    }
+
+    // Keep the request inside [0, total_elements]: asking for more entries than exist would
+    // make the rejection loop below spin forever, and a negative fraction cannot be converted
+    // to size_t.
+    size_t num_ones = 0;
+    if (params.sparsity >= 1.0f) {
+      num_ones = total_elements;
+    } else if (params.sparsity > 0.0f) {
+      num_ones = static_cast<size_t>((total_elements * 1.0f) * params.sparsity);
+      if (num_ones > total_elements) { num_ones = total_elements; }
+    }
+    const size_t res = num_ones;
+    // Also covers an empty matrix, for which the index range below would be invalid.
+    if (num_ones == 0) { return res; }
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    // size_t bounds: the default int distribution cannot represent indices past INT_MAX.
+    std::uniform_int_distribution<size_t> dis(0, total_elements - 1);
+
+    while (num_ones > 0) {
+      size_t index = dis(gen);
+
+      if (matrix[index] == false) {
+        matrix[index] = true;
+        num_ones--;
+      }
+    }
+    return res;
+  }
+
+  ~SDDMMBench() {}
+
+  /**
+   * @brief Labels the run with the parameters, issues one sddmm call outside the benchmark loop
+   * and then passes the selected operation to loop_on_state.
+   *
+   * Both branches of the loop body end with a stream synchronization.
+   *
+   * @param state google-benchmark state forwarded to loop_on_state
+   */
+  void run_benchmark(::benchmark::State& state) override
+  {
+    std::ostringstream label_stream;
+    label_stream << params;
+    state.SetLabel(label_stream.str());
+
+    // A transposed operand is viewed with its two extents swapped.
+    auto a = raft::make_device_matrix_view(
+      a_data_d.data(),
+      (!params.transpose_a ? params.m : params.k),
+      (!params.transpose_a ? params.k : params.m));
+
+    auto b = raft::make_device_matrix_view(
+      b_data_d.data(),
+      (!params.transpose_b ? params.k : params.n),
+      (!params.transpose_b ? params.n : params.k));
+
+    auto c_structure = raft::make_device_compressed_structure_view(
+      c_indptr_d.data(),
+      c_indices_d.data(),
+      params.m,
+      params.n,
+      static_cast(c_indices_d.size()));
+
+    auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure);
+    // NOTE(review): result unused; presumably called to create the cuSPARSE handle up front.
+    raft::resource::get_cusparse_handle(handle);
+
+    resource::sync_stream(handle);
+
+    auto op_a = params.transpose_a ? raft::linalg::Operation::TRANSPOSE
+                                   : raft::linalg::Operation::NON_TRANSPOSE;
+    auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE
+                                   : raft::linalg::Operation::NON_TRANSPOSE;
+
+    raft::sparse::linalg::sddmm(handle,
+                                a,
+                                b,
+                                c,
+                                op_a,
+                                op_b,
+                                raft::make_host_scalar_view(&params.alpha),
+                                raft::make_host_scalar_view(&params.beta));
+    resource::sync_stream(handle);
+
+    loop_on_state(state, [this, &a, &b, &c, &op_a, &op_b]() {
+      if (SDDMMorInner == Alg::SDDMM) {
+        raft::sparse::linalg::sddmm(handle,
+                                    a,
+                                    b,
+                                    c,
+                                    op_a,
+                                    op_b,
+                                    raft::make_host_scalar_view(&params.alpha),
+                                    raft::make_host_scalar_view(&params.beta));
+        resource::sync_stream(handle);
+      } else {
+        raft::distance::pairwise_distance(handle,
+                                          a_data_d.data(),
+                                          b_data_d.data(),
+                                          c_dense_data_d.data(),
+                                          static_cast(params.m),
+                                          static_cast(params.n),
+                                          static_cast(params.k),
+                                          raft::distance::DistanceType::InnerProduct,
+                                          std::is_same_v);
+        resource::sync_stream(handle);
+      }
+    });
+  }
+
+ private:
+  const raft::device_resources handle;
+  SDDMMBenchParams params;
+
+  rmm::device_uvector a_data_d;
+  rmm::device_uvector b_data_d;
+  rmm::device_uvector c_dense_data_d;  // only resized for Alg::Inner
+
+  size_t c_true_nnz = 0;
+  rmm::device_uvector c_indptr_d;
+  rmm::device_uvector c_indices_d;
+  rmm::device_uvector c_data_d;
+};
+
+/**
+ * @brief Builds the benchmark cases from the value lists for (transpose_a, transpose_b, m, k, n,
+ * sparsity) and converts each one to SDDMMBenchParams.
+ *
+ * NOTE(review): presumably itertools::product yields every combination of the listed values
+ * (2 * 2 * 2 * 2 * 1 * 4 = 64 cases) - confirm against raft/util/itertools.
+ */
+template
+static std::vector> getInputs()
+{
+  std::vector> param_vec;
+  struct TestParams {
+    bool transpose_a;
+    bool transpose_b;
+    size_t m;
+    size_t k;
+    size_t n;
+    float sparsity;
+  };
+
+  const std::vector params_group =
+    raft::util::itertools::product({false, true},
+                                   {false, true},
+                                   {size_t(10), size_t(1024)},
+                                   {size_t(128), size_t(1024)},
+                                   {size_t(1024 * 1024)},
+                                   {0.01f, 0.1f, 0.2f, 0.5f});
+
+  param_vec.reserve(params_group.size());
+  // TestParams orders the transpose flags first; SDDMMBenchParams orders them last.
+  for (const TestParams& params : params_group) {
+    param_vec.push_back(SDDMMBenchParams(
+      {params.m, params.k, params.n, params.sparsity, params.transpose_a, params.transpose_b}));
+  }
+  return param_vec;
+}
+
+RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs());
+RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs());
+RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs());
+RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs());
+
+RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs());
+
+}  // namespace raft::bench::linalg