diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 51e1c41499..6b2d463d0e 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(${RAFT_CPP_BENCH_TARGET} bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu + bench/sparse/convert_csr.cu bench/spatial/fused_l2_nn.cu bench/spatial/knn.cu bench/spatial/selection.cu diff --git a/cpp/bench/sparse/convert_csr.cu b/cpp/bench/sparse/convert_csr.cu new file mode 100644 index 0000000000..0e701518ab --- /dev/null +++ b/cpp/bench/sparse/convert_csr.cu @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t num_cols; + index_t num_rows; + index_t divisor; +}; + +template +__global__ void init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor) +{ + index_t r = blockDim.y * blockIdx.y + threadIdx.y; + index_t c0 = blockDim.x * blockIdx.x + threadIdx.x; // first column of this thread; the column loop restarts here for every row + + for (; r < num_rows; r += gridDim.y * blockDim.y) { + for (index_t c = c0; c < num_cols; c += gridDim.x * blockDim.x) { + adj[r * num_cols + c] = c % divisor == 0; + } + } +} + +template +void init_adj(bool* adj, index_t num_rows, index_t num_cols, index_t divisor, cudaStream_t stream) +{ + // adj matrix: element a_ij is set to one if j is divisible by divisor. 
+ dim3 block(32, 32); + const index_t max_y_grid_dim = 65535; + dim3 grid(num_cols / 32 + 1, (int)min(num_rows / 32 + 1, max_y_grid_dim)); + init_adj_kernel<<>>(adj, num_rows, num_cols, divisor); + RAFT_CHECK_CUDA(stream); +} + +template +struct bench_base : public fixture { + bench_base(const bench_param& p) + : handle(stream), + params(p), + adj(p.num_rows * p.num_cols, stream), + row_ind(p.num_rows, stream), + row_ind_host(p.num_rows), + row_counters(p.num_rows, stream), + // Every row of adj holds ceildiv(num_cols, divisor) non-zeros, so nnz is known exactly. + col_ind(p.num_rows * raft::ceildiv(p.num_cols, p.divisor), stream) + { + init_adj(adj.data(), p.num_rows, p.num_cols, p.divisor, stream); + + // Fill the row_ind_host member itself; a local vector of the same name would shadow it. + const size_t nnz_per_row = raft::ceildiv(p.num_cols, p.divisor); + for (size_t i = 0; i < row_ind_host.size(); ++i) { + row_ind_host[i] = nnz_per_row * i; + } + raft::update_device(row_ind.data(), row_ind_host.data(), row_ind.size(), stream); + } + + void run_benchmark(::benchmark::State& state) override + { + loop_on_state(state, [this]() { + raft::sparse::convert::adj_to_csr(handle, + adj.data(), + row_ind.data(), + params.num_rows, + params.num_cols, + row_counters.data(), + col_ind.data()); + }); + + // Estimate bandwidth: + index_t num_entries = params.num_rows * params.num_cols; + index_t bytes_read = num_entries * sizeof(bool); + index_t bytes_write = num_entries / params.divisor * sizeof(index_t); + + state.counters["BW"] = benchmark::Counter(bytes_read + bytes_write, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1024); + state.counters["BW read"] = benchmark::Counter( + bytes_read, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1024); + state.counters["BW write"] = benchmark::Counter(bytes_write, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1024); + + state.counters["Fraction nz"] = benchmark::Counter(100.0 / ((double)params.divisor)); + state.counters["Columns"] = 
benchmark::Counter(params.num_cols); + state.counters["Rows"] = benchmark::Counter(params.num_rows); + } + + protected: + raft::handle_t handle; + bench_param params; + rmm::device_uvector adj; + rmm::device_uvector row_ind; + std::vector row_ind_host; + rmm::device_uvector row_counters; + rmm::device_uvector col_ind; +}; // struct bench_base + +const int64_t num_cols = 1 << 30; + +const std::vector> bench_params = { + {num_cols, 1, 8}, + {num_cols >> 3, 1 << 3, 8}, + {num_cols >> 6, 1 << 6, 8}, + + {num_cols, 1, 64}, + {num_cols >> 3, 1 << 3, 64}, + {num_cols >> 6, 1 << 6, 64}, + + {num_cols, 1, 2048}, + {num_cols >> 3, 1 << 3, 2048}, + {num_cols >> 6, 1 << 6, 2048}, +}; + +RAFT_BENCH_REGISTER(bench_base, "", bench_params); + +} // namespace raft::bench::sparse diff --git a/cpp/doxygen/Doxyfile.in b/cpp/doxygen/Doxyfile.in index 593a62ef94..7f83154b5d 100644 --- a/cpp/doxygen/Doxyfile.in +++ b/cpp/doxygen/Doxyfile.in @@ -844,9 +844,8 @@ EXCLUDE_PATTERNS = */detail/* \ # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories use the pattern */test/* -EXCLUDE_SYMBOLS = detail \ - csr_adj_graph \ - csr_adj_graph_batched +EXCLUDE_SYMBOLS = detail + # The EXAMPLE_PATH tag can be used to specify one or more files or directories # that contain example code fragments that are included (see the \include diff --git a/cpp/include/raft/sparse/convert/coo.hpp b/cpp/include/raft/sparse/convert/coo.hpp index 697452db09..1c8a2625c7 100644 --- a/cpp/include/raft/sparse/convert/coo.hpp +++ b/cpp/include/raft/sparse/convert/coo.hpp @@ -18,34 +18,10 @@ * Please use the cuh version instead. 
*/ -#ifndef __COO_H -#define __COO_H - #pragma once -#include - -namespace raft { -namespace sparse { -namespace convert { - -/** - * @brief Convert a CSR row_ind array to a COO rows array - * @param row_ind: Input CSR row_ind array - * @param m: size of row_ind array - * @param coo_rows: Output COO row array - * @param nnz: size of output COO row array - * @param stream: cuda stream to use - */ -template -void csr_to_coo( - const value_idx* row_ind, value_idx m, value_idx* coo_rows, value_idx nnz, cudaStream_t stream) -{ - detail::csr_to_coo(row_ind, m, coo_rows, nnz, stream); -} - -}; // end NAMESPACE convert -}; // end NAMESPACE sparse -}; // end NAMESPACE raft +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the cuh version instead.") -#endif \ No newline at end of file +#include "coo.cuh" diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 10bc22bcc1..abdacdc426 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -18,6 +18,7 @@ #pragma once +#include #include #include @@ -39,74 +40,6 @@ void coo_to_csr(const raft::handle_t& handle, detail::coo_to_csr(handle, srcRows, srcCols, srcVals, nnz, m, dst_offsets, dstCols, dstVals); } -/** - * @brief Constructs an adjacency graph CSR row_ind_ptr array from - * a row_ind array and adjacency array. 
- * @tparam T the numeric type of the index arrays - * @tparam TPB_X the number of threads to use per block for kernels - * @tparam Lambda function for fused operation in the adj_graph construction - * @param row_ind the input CSR row_ind array - * @param total_rows number of vertices in graph - * @param nnz number of non-zeros - * @param batchSize number of vertices in current batch - * @param adj an adjacency array (size batchSize x total_rows) - * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph - * @param stream cuda stream to use - * @param fused_op: the fused operation - */ -template void> -void csr_adj_graph_batched(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - Index_ batchSize, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream, - Lambda fused_op) -{ - detail::csr_adj_graph_batched( - row_ind, total_rows, nnz, batchSize, adj, row_ind_ptr, stream, fused_op); -} - -template void> -void csr_adj_graph_batched(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - Index_ batchSize, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream) -{ - detail::csr_adj_graph_batched( - row_ind, total_rows, nnz, batchSize, adj, row_ind_ptr, stream); -} - -/** - * @brief Constructs an adjacency graph CSR row_ind_ptr array from a - * a row_ind array and adjacency array. 
- * @tparam T the numeric type of the index arrays - * @tparam TPB_X the number of threads to use per block for kernels - * @param row_ind the input CSR row_ind array - * @param total_rows number of total vertices in graph - * @param nnz number of non-zeros - * @param adj an adjacency array - * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph - * @param stream cuda stream to use - * @param fused_op the fused operation - */ -template void> -void csr_adj_graph(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream, - Lambda fused_op) -{ - detail::csr_adj_graph( - row_ind, total_rows, nnz, adj, row_ind_ptr, stream, fused_op); -} - /** * @brief Generate the row indices array for a sorted COO matrix * @@ -135,8 +68,42 @@ void sorted_coo_to_csr(COO* coo, int* row_ind, cudaStream_t stream) detail::sorted_coo_to_csr(coo->rows(), coo->nnz, row_ind, coo->n_rows, stream); } +/** + * @brief Converts a boolean adjacency matrix into unsorted CSR format. + * + * The conversion supports non-square matrices. + * + * @tparam index_t Indexing arithmetic type + * + * @param[in] handle RAFT handle + * @param[in] adj A num_rows x num_cols boolean matrix in contiguous row-major + * format. + * @param[in] row_ind An array of length num_rows that indicates at which index + * a row starts in out_col_ind. Equivalently, it is the + * exclusive scan of the number of non-zeros in each row of + * adj. + * @param[in] num_rows Number of rows of adj. + * @param[in] num_cols Number of columns of adj. + * @param tmp A pre-allocated array of size num_rows. + * @param[out] out_col_ind An array containing the column indices of the + * non-zero values in adj. Size should be at least the + * number of non-zeros in adj. 
+ */ +template +void adj_to_csr(const raft::handle_t& handle, + const bool* adj, // Row-major adjacency matrix + const index_t* row_ind, // Precomputed row indices + index_t num_rows, // # rows of adj + index_t num_cols, // # cols of adj + index_t* tmp, // Pre-allocated atomic counters. Minimum size: num_rows elements. + index_t* out_col_ind // Output column indices +) +{ + detail::adj_to_csr(handle, adj, row_ind, num_rows, num_cols, tmp, out_col_ind); +} + }; // end NAMESPACE convert }; // end NAMESPACE sparse }; // end NAMESPACE raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/sparse/convert/csr.hpp b/cpp/include/raft/sparse/convert/csr.hpp index cd5d89bf71..250c6855d0 100644 --- a/cpp/include/raft/sparse/convert/csr.hpp +++ b/cpp/include/raft/sparse/convert/csr.hpp @@ -13,135 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * This file is deprecated and will be removed in release 22.06. * Please use the cuh version instead. */ -#ifndef __CSR_H -#define __CSR_H - -#pragma once - -#include -#include - -namespace raft { -namespace sparse { -namespace convert { - -template -void coo_to_csr(const raft::handle_t& handle, - const int* srcRows, - const int* srcCols, - const value_t* srcVals, - int nnz, - int m, - int* dst_offsets, - int* dstCols, - value_t* dstVals) -{ - detail::coo_to_csr(handle, srcRows, srcCols, srcVals, nnz, m, dst_offsets, dstCols, dstVals); -} - /** - * @brief Constructs an adjacency graph CSR row_ind_ptr array from - * a row_ind array and adjacency array. 
- * @tparam T the numeric type of the index arrays - * @tparam TPB_X the number of threads to use per block for kernels - * @tparam Lambda function for fused operation in the adj_graph construction - * @param row_ind the input CSR row_ind array - * @param total_rows number of vertices in graph - * @param nnz number of non-zeros - * @param batchSize number of vertices in current batch - * @param adj an adjacency array (size batchSize x total_rows) - * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph - * @param stream cuda stream to use - * @param fused_op: the fused operation + * DISCLAIMER: this file is deprecated: use csr.cuh instead */ -template void> -void csr_adj_graph_batched(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - Index_ batchSize, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream, - Lambda fused_op) -{ - detail::csr_adj_graph_batched( - row_ind, total_rows, nnz, batchSize, adj, row_ind_ptr, stream, fused_op); -} - -template void> -void csr_adj_graph_batched(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - Index_ batchSize, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream) -{ - detail::csr_adj_graph_batched( - row_ind, total_rows, nnz, batchSize, adj, row_ind_ptr, stream); -} -/** - * @brief Constructs an adjacency graph CSR row_ind_ptr array from a - * a row_ind array and adjacency array. 
- * @tparam T the numeric type of the index arrays - * @tparam TPB_X the number of threads to use per block for kernels - * @param row_ind the input CSR row_ind array - * @param total_rows number of total vertices in graph - * @param nnz number of non-zeros - * @param adj an adjacency array - * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph - * @param stream cuda stream to use - * @param fused_op the fused operation - */ -template void> -void csr_adj_graph(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream, - Lambda fused_op) -{ - detail::csr_adj_graph( - row_ind, total_rows, nnz, adj, row_ind_ptr, stream, fused_op); -} - -/** - * @brief Generate the row indices array for a sorted COO matrix - * - * @param rows: COO rows array - * @param nnz: size of COO rows array - * @param row_ind: output row indices array - * @param m: number of rows in dense matrix - * @param stream: cuda stream to use - */ -template -void sorted_coo_to_csr(const T* rows, int nnz, T* row_ind, int m, cudaStream_t stream) -{ - detail::sorted_coo_to_csr(rows, nnz, row_ind, m, stream); -} - -/** - * @brief Generate the row indices array for a sorted COO matrix - * - * @param coo: Input COO matrix - * @param row_ind: output row indices array - * @param stream: cuda stream to use - */ -template -void sorted_coo_to_csr(COO* coo, int* row_ind, cudaStream_t stream) -{ - detail::sorted_coo_to_csr(coo->rows(), coo->nnz, row_ind, coo->n_rows, stream); -} +#pragma once -}; // end NAMESPACE convert -}; // end NAMESPACE sparse -}; // end NAMESPACE raft +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." 
\ + " Please use the cuh version instead.") -#endif \ No newline at end of file +#include "csr.cuh" diff --git a/cpp/include/raft/sparse/convert/dense.hpp b/cpp/include/raft/sparse/convert/dense.hpp index f8338536c8..2d2c527e7e 100644 --- a/cpp/include/raft/sparse/convert/dense.hpp +++ b/cpp/include/raft/sparse/convert/dense.hpp @@ -18,55 +18,10 @@ * Please use the cuh version instead. */ -#ifndef __DENSE_H -#define __DENSE_H - #pragma once -#include - -namespace raft { -namespace sparse { -namespace convert { - -/** - * Convert CSR arrays to a dense matrix in either row- - * or column-major format. A custom kernel is used when - * row-major output is desired since cusparse does not - * output row-major. - * @tparam value_idx : data type of the CSR index arrays - * @tparam value_t : data type of the CSR value array - * @param[in] handle : cusparse handle for conversion - * @param[in] nrows : number of rows in CSR - * @param[in] ncols : number of columns in CSR - * @param[in] nnz : number of nonzeros in CSR - * @param[in] csr_indptr : CSR row index pointer array - * @param[in] csr_indices : CSR column indices array - * @param[in] csr_data : CSR data array - * @param[in] lda : Leading dimension (used for col-major only) - * @param[out] out : Dense output array of size nrows * ncols - * @param[in] stream : Cuda stream for ordering events - * @param[in] row_major : Is row-major output desired? - */ -template -void csr_to_dense(cusparseHandle_t handle, - value_idx nrows, - value_idx ncols, - value_idx nnz, - const value_idx* csr_indptr, - const value_idx* csr_indices, - const value_t* csr_data, - value_idx lda, - value_t* out, - cudaStream_t stream, - bool row_major = true) -{ - detail::csr_to_dense( - handle, nrows, ncols, nnz, csr_indptr, csr_indices, csr_data, lda, out, stream, row_major); -} - -}; // end NAMESPACE convert -}; // end NAMESPACE sparse -}; // end NAMESPACE raft +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." 
\ + " Please use the cuh version instead.") -#endif \ No newline at end of file +#include "dense.cuh" diff --git a/cpp/include/raft/sparse/convert/detail/adj_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/adj_to_csr.cuh new file mode 100644 index 0000000000..e55627c936 --- /dev/null +++ b/cpp/include/raft/sparse/convert/detail/adj_to_csr.cuh @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace convert { +namespace detail { + +/** + * @brief Convert dense adjacency matrix into unsorted CSR format. + * + * The adj_to_csr kernel converts a boolean adjacency matrix into CSR + * format. High performance comes at the cost of non-deterministic output: the + * column indices are not guaranteed to be stored in order. + * + * The kernel has been optimized to handle matrices that are non-square, for + * instance subsets of a full adjacency matrix. In practice, these matrices can + * be very wide and not very tall. In principle, each row is assigned to one + * block. If there are more SMs than rows, multiple blocks operate on a single + * row. To enable cooperation between these blocks, each row is provided a + * counter where the current output index can be cooperatively (atomically) + * incremented. As a result, the order of the output indices is not guaranteed. 
+ * + * @param[in] adj A num_rows x num_cols boolean matrix in contiguous row-major + * format. + * @param[in] row_ind An array of length num_rows that indicates at which index + * a row starts in out_col_ind. Equivalently, it is the + * exclusive scan of the number of non-zeros in each row of + * adj. + * @param[in] num_rows Number of rows of adj. + * @param[in] num_cols Number of columns of adj. + * @param[in,out] row_counters A temporary zero-initialized array of length num_rows. + * @param[out] out_col_ind An array containing the column indices of the + * non-zero values in `adj`. Size should be at least + * the number of non-zeros in `adj`. + */ +template +__global__ void adj_to_csr_kernel(const bool* adj, // row-major adjacency matrix + const index_t* row_ind, // precomputed row indices + index_t num_rows, // # rows of adj + index_t num_cols, // # cols of adj + index_t* row_counters, // pre-allocated (zeroed) atomic counters + index_t* out_col_ind // output column indices +) +{ + const int chunk_size = 16; + typedef raft::TxN_t chunk_bool; + + for (index_t i = blockIdx.y; i < num_rows; i += gridDim.y) { + // Load row information + index_t row_base = row_ind[i]; + index_t* row_count = row_counters + i; + const bool* row = adj + i * num_cols; + + // Peeling: process the first j0 elements that are not aligned to a chunk_size-byte + // boundary. + index_t j0 = (chunk_size - (((uintptr_t)(const void*)row) % chunk_size)) % chunk_size; + j0 = min(j0, num_cols); + if (threadIdx.x < j0 && blockIdx.x == 0) { + if (row[threadIdx.x]) { out_col_ind[row_base + atomicIncWarp(row_count)] = threadIdx.x; } + } + + // Process the rest of the row in chunk_size byte chunks starting at j0. + // This is a grid-stride loop. 
+ index_t j = j0 + chunk_size * (blockIdx.x * blockDim.x + threadIdx.x); + for (; j + chunk_size - 1 < num_cols; j += chunk_size * (blockDim.x * gridDim.x)) { + chunk_bool chunk; + chunk.load(row, j); + for (int k = 0; k < chunk_size; ++k) { + if (chunk.val.data[k]) { out_col_ind[row_base + atomicIncWarp(row_count)] = j + k; } + } + } + + // Remainder: process the last j1 bools in the row individually. + index_t j1 = (num_cols - j0) % chunk_size; + if (threadIdx.x < j1 && blockIdx.x == 0) { + index_t j = num_cols - j1 + threadIdx.x; // index_t, not int: the column index must not be truncated for rows wider than INT_MAX + if (row[j]) { out_col_ind[row_base + atomicIncWarp(row_count)] = j; } + } + } +} + +/** + * @brief Converts a boolean adjacency matrix into unsorted CSR format. + * + * The conversion supports non-square matrices. + * + * @tparam index_t Indexing arithmetic type + * + * @param[in] handle RAFT handle + * @param[in] adj A num_rows x num_cols boolean matrix in contiguous row-major + * format. + * @param[in] row_ind An array of length num_rows that indicates at which index + * a row starts in out_col_ind. Equivalently, it is the + * exclusive scan of the number of non-zeros in each row of + * adj. + * @param[in] num_rows Number of rows of adj. + * @param[in] num_cols Number of columns of adj. + * @param[in,out] tmp A pre-allocated array of size num_rows. It is zero-filled by this function and then used as per-row output counters. + * @param[out] out_col_ind An array containing the column indices of the + * non-zero values in adj. Size should be at least the + * number of non-zeros in adj. + */ +template +void adj_to_csr(const raft::handle_t& handle, + const bool* adj, // row-major adjacency matrix + const index_t* row_ind, // precomputed row indices + index_t num_rows, // # rows of adj + index_t num_cols, // # cols of adj + index_t* tmp, // pre-allocated atomic counters + index_t* out_col_ind // output column indices +) +{ + auto stream = handle.get_stream(); + + // Check inputs and return early if possible. 
+ if (num_rows == 0 || num_cols == 0) { return; } + RAFT_EXPECTS(tmp != nullptr, "adj_to_csr: tmp workspace may not be null."); + + // Zero-fill a temporary vector that is used by the kernel to keep track of + // the number of entries added to a row. + RAFT_CUDA_TRY(cudaMemsetAsync(tmp, 0, num_rows * sizeof(index_t), stream)); + + // Split the grid in the row direction (since each row can be processed + // independently). If the maximum number of active blocks (num_sms * + // occupancy) exceeds the number of rows, assign multiple blocks to a single + // row. + int threads_per_block = 1024; + int dev_id, sm_count, blocks_per_sm; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id)); + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, adj_to_csr_kernel, threads_per_block, 0)); + RAFT_EXPECTS(blocks_per_sm > 0, "adj_to_csr: kernel cannot be resident with the chosen block size."); // zero occupancy would make the ceildiv calls below divide by zero + index_t max_active_blocks = sm_count * blocks_per_sm; + index_t blocks_per_row = raft::ceildiv(max_active_blocks, num_rows); + index_t grid_rows = raft::ceildiv(max_active_blocks, blocks_per_row); + dim3 block(threads_per_block, 1); + dim3 grid(blocks_per_row, grid_rows); + + adj_to_csr_kernel + <<>>(adj, row_ind, num_rows, num_cols, tmp, out_col_ind); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +}; // end NAMESPACE detail +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/convert/detail/csr.cuh b/cpp/include/raft/sparse/convert/detail/csr.cuh index 2516d00533..d945a3c785 100644 --- a/cpp/include/raft/sparse/convert/detail/csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/csr.cuh @@ -74,95 +74,6 @@ void coo_to_csr(const raft::handle_t& handle, RAFT_CUDA_TRY(cudaDeviceSynchronize()); } -/** - * @brief Constructs an adjacency graph CSR row_ind_ptr array from - * a row_ind array and adjacency array. 
- * @tparam T the numeric type of the index arrays - * @tparam TPB_X the number of threads to use per block for kernels - * @tparam Lambda function for fused operation in the adj_graph construction - * @param row_ind the input CSR row_ind array - * @param total_rows number of vertices in graph - * @param nnz number of non-zeros - * @param batchSize number of vertices in current batch - * @param adj an adjacency array (size batchSize x total_rows) - * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph - * @param stream cuda stream to use - * @param fused_op: the fused operation - */ -template void> -void csr_adj_graph_batched(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - Index_ batchSize, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream, - Lambda fused_op) -{ - op::csr_row_op( - row_ind, - batchSize, - nnz, - [fused_op, adj, total_rows, row_ind_ptr, batchSize, nnz] __device__( - Index_ row, Index_ start_idx, Index_ stop_idx) { - fused_op(row, start_idx, stop_idx); - Index_ k = 0; - for (Index_ i = 0; i < total_rows; i++) { - // @todo: uncoalesced mem accesses! - if (adj[batchSize * i + row]) { - row_ind_ptr[start_idx + k] = i; - k += 1; - } - } - }, - stream); -} - -template void> -void csr_adj_graph_batched(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - Index_ batchSize, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream) -{ - csr_adj_graph_batched(row_ind, - total_rows, - nnz, - batchSize, - adj, - row_ind_ptr, - stream, - [] __device__(Index_ row, Index_ start_idx, Index_ stop_idx) {}); -} - -/** - * @brief Constructs an adjacency graph CSR row_ind_ptr array from a - * a row_ind array and adjacency array. 
- * @tparam T the numeric type of the index arrays - * @tparam TPB_X the number of threads to use per block for kernels - * @param row_ind the input CSR row_ind array - * @param total_rows number of total vertices in graph - * @param nnz number of non-zeros - * @param adj an adjacency array - * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph - * @param stream cuda stream to use - * @param fused_op the fused operation - */ -template void> -void csr_adj_graph(const Index_* row_ind, - Index_ total_rows, - Index_ nnz, - const bool* adj, - Index_* row_ind_ptr, - cudaStream_t stream, - Lambda fused_op) -{ - csr_adj_graph_batched( - row_ind, total_rows, nnz, total_rows, adj, row_ind_ptr, stream, fused_op); -} - /** * @brief Generate the row indices array for a sorted COO matrix * diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index bce4686bf0..a217a90e19 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -16,7 +16,7 @@ #include "../test_utils.h" #include -#include +#include #include #include @@ -88,55 +88,108 @@ INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, SortedCOOToCSR, ::testing::ValuesI /******************************** adj graph ********************************/ -template +template +__global__ void init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor) +{ + index_t r = blockDim.y * blockIdx.y + threadIdx.y; + index_t c0 = blockDim.x * blockIdx.x + threadIdx.x; // first column of this thread; the column loop restarts here for every row + + for (; r < num_rows; r += gridDim.y * blockDim.y) { + for (index_t c = c0; c < num_cols; c += gridDim.x * blockDim.x) { + adj[r * num_cols + c] = c % divisor == 0; + } + } +} + +template +void init_adj(bool* adj, index_t num_rows, index_t num_cols, index_t divisor, cudaStream_t stream) +{ + // adj matrix: element a_ij is set to one if j is divisible by divisor. 
+ dim3 block(32, 32); + const index_t max_y_grid_dim = 65535; + dim3 grid(num_cols / 32 + 1, (int)min(num_rows / 32 + 1, max_y_grid_dim)); + init_adj_kernel<<>>(adj, num_rows, num_cols, divisor); + RAFT_CHECK_CUDA(stream); +} + +template struct CSRAdjGraphInputs { - Index_ n_rows; - Index_ n_cols; - std::vector row_ind; - std::vector adj; // To avoid vector optimization - std::vector verify; + index_t n_rows; + index_t n_cols; + index_t divisor; }; -template -class CSRAdjGraphTest : public ::testing::TestWithParam> { +template +class CSRAdjGraphTest : public ::testing::TestWithParam> { public: CSRAdjGraphTest() - : params(::testing::TestWithParam>::GetParam()), - stream(handle.get_stream()), - row_ind(params.n_rows, stream), + : stream(handle.get_stream()), + params(::testing::TestWithParam>::GetParam()), adj(params.n_rows * params.n_cols, stream), - result(params.verify.size(), stream), - verify(params.verify.size(), stream) + row_ind(params.n_rows, stream), + row_counters(params.n_rows, stream), + col_ind(params.n_rows * params.n_cols, stream), + row_ind_host(params.n_rows) { } protected: - void SetUp() override { nnz = params.verify.size(); } + void SetUp() override + { + // Initialize adj matrix: element a_ij equals one if j is divisible by + // params.divisor. + init_adj(adj.data(), params.n_rows, params.n_cols, params.divisor, stream); + // Initialize row_ind + for (size_t i = 0; i < row_ind_host.size(); ++i) { + size_t nnz_per_row = raft::ceildiv(params.n_cols, params.divisor); + row_ind_host[i] = nnz_per_row * i; + } + raft::update_device(row_ind.data(), row_ind_host.data(), row_ind.size(), stream); + + // Initialize result to 1, so we can catch any errors. 
+ RAFT_CUDA_TRY(cudaMemsetAsync(col_ind.data(), 1, col_ind.size() * sizeof(index_t), stream)); + } void Run() { - raft::update_device(row_ind.data(), params.row_ind.data(), params.n_rows, stream); - raft::update_device(adj.data(), - reinterpret_cast(params.adj.data()), - params.n_rows * params.n_cols, - stream); - raft::update_device(verify.data(), params.verify.data(), nnz, stream); - - convert::csr_adj_graph_batched( - row_ind.data(), params.n_cols, nnz, params.n_rows, adj.data(), result.data(), stream); - - ASSERT_TRUE( - raft::devArrMatch(verify.data(), result.data(), nnz, raft::Compare())); + convert::adj_to_csr(handle, + adj.data(), + row_ind.data(), + params.n_rows, + params.n_cols, + row_counters.data(), + col_ind.data()); + + std::vector col_ind_host(col_ind.size()); + raft::update_host(col_ind_host.data(), col_ind.data(), col_ind.size(), stream); + std::vector row_counters_host(params.n_rows); + raft::update_host(row_counters_host.data(), row_counters.data(), row_counters.size(), stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + + // 1. Check that each row contains enough values + index_t nnz_per_row = raft::ceildiv(params.n_cols, params.divisor); + for (index_t i = 0; i < params.n_rows; ++i) { + ASSERT_EQ(row_counters_host[i], nnz_per_row) << "where i = " << i; + } + // 2. 
Check that all column indices are divisible by divisor + for (index_t i = 0; i < params.n_rows; ++i) { + index_t row_base = row_ind_host[i]; + for (index_t j = 0; j < nnz_per_row; ++j) { + ASSERT_EQ(0, col_ind_host[row_base + j] % params.divisor); + } + } } protected: raft::handle_t handle; cudaStream_t stream; - CSRAdjGraphInputs params; - Index_ nnz; - rmm::device_uvector row_ind, result, verify; + CSRAdjGraphInputs params; rmm::device_uvector adj; + rmm::device_uvector row_ind; + rmm::device_uvector row_counters; + rmm::device_uvector col_ind; + std::vector row_ind_host; }; using CSRAdjGraphTestI = CSRAdjGraphTest; @@ -145,19 +198,15 @@ TEST_P(CSRAdjGraphTestI, Result) { Run(); } using CSRAdjGraphTestL = CSRAdjGraphTest; TEST_P(CSRAdjGraphTestL, Result) { Run(); } -const std::vector> csradjgraph_inputs_i = { - {3, - 6, - {0, 3, 6}, - {1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {0, 1, 2, 0, 1, 2, 0, 1, 2}}, -}; +const std::vector> csradjgraph_inputs_i = {{10, 10, 2}}; const std::vector> csradjgraph_inputs_l = { - {3, - 6, - {0, 3, 6}, - {1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {0, 1, 2, 0, 1, 2, 0, 1, 2}}, + {0, 0, 2}, + {10, 10, 2}, + {64 * 1024 + 10, 2, 3}, // 64K + 10 rows is slightly over the maximum of gridDim.y (65535) + {16, 16, 3}, // No peeling-remainder + {16, 17, 3}, // Check peeling-remainder: 17 columns, so row starts are not all chunk-aligned + {16, 18, 3}, // Check peeling-remainder: 18 columns, so row starts are not all chunk-aligned + {32 + 9, 33, 2}, // Check peeling-remainder }; INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest,