From 2fad4c0cdf91780f36cb30b5f2bcec6632646382 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Sat, 23 Jul 2022 01:21:14 +0200 Subject: [PATCH] Accelerate adjacency matrix to CSR conversion for DBSCAN (#4803) Fixes issue #2387. For large data sizes, the batch size of the DBSCAN algorithm is small in order to fit the distance matrix in memory. This results in a matrix that has dimensions num_points x batch_size, both for the distance and adjacency matrix. The conversion of the boolean adjacency matrix to CSR format is performed in the 'adjgraph' step. This step was slow when the batch size was small, as described in issue #2387. In this commit, the adjgraph step is sped up. This is done in two ways: 1. The adjacency matrix is now stored in row-major batch_size x num_points format --- it was transposed before. This required changes in the vertexdeg step. 2. The csr_row_op kernel has been replaced by the adj_to_csr kernel. This kernel can divide the work over multiple blocks even when the number of rows (batch size) is small. It makes optimal use of memory bandwidth because rows of the matrix are laid out contiguously in memory. Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/cuml/pull/4803 --- cpp/src/dbscan/adjgraph/algo.cuh | 131 +++++++++++++++++++++-- cpp/src/dbscan/adjgraph/runner.cuh | 4 +- cpp/src/dbscan/runner.cuh | 20 +++- cpp/src/dbscan/vertexdeg/algo.cuh | 40 +++++-- cpp/src/dbscan/vertexdeg/precomputed.cuh | 47 ++++---- cpp/src/dbscan/vertexdeg/runner.cuh | 1 + 6 files changed, 196 insertions(+), 47 deletions(-) diff --git a/cpp/src/dbscan/adjgraph/algo.cuh b/cpp/src/dbscan/adjgraph/algo.cuh index 8a2065f9a5..a1978b429a 100644 --- a/cpp/src/dbscan/adjgraph/algo.cuh +++ b/cpp/src/dbscan/adjgraph/algo.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -23,35 +24,149 @@ #include "pack.h" #include -#include +#include +#include +#include +#include namespace ML { namespace Dbscan { namespace AdjGraph { namespace Algo { -static const int TPB_X = 256; +/** + * @brief Convert a boolean adjacency matrix into CSR format. + * + * The adj_to_csr kernel converts a boolean adjacency matrix into CSR format. + * High performance comes at the cost of non-deterministic output: the column + * indices are not guaranteed to be stored in order. + * + * The kernel has been optimized to handle matrices that are non-square, for + * instance subsets of a full adjacency matrix. In practice, these matrices can + * be very wide and not very tall. In principle, each row is assigned to one + * block. If there are more SMs than rows, multiple blocks operate on a single + * row. To enable cooperation between these blocks, each row is provided a + * counter where the current output index can be cooperatively (atomically) + * incremented. As a result, the order of the output indices is not guaranteed + * to be in order. + * + * @param[in] adj: a num_rows x num_cols boolean matrix in contiguous row-major + * format. + * + * @param[in] row_ind: an array of length num_rows that indicates at which index + * a row starts in out_col_ind. Equivalently, it is the + * exclusive scan of the number of non-zeros in each row of + * `adj`. + * + * @param[in] num_rows: number of rows of adj. + * @param[in] num_cols: number of columns of adj. + * + * @param[in,out] row_counters: a temporary zero-initialized array of length num_rows. + * + * @param[out] out_col_ind: an array containing the column indices of the + * non-zero values in `adj`. Size should be at least + * the number of non-zeros in `adj`. + */ +template +__global__ void adj_to_csr(const bool* adj, // row-major adjacency matrix + const index_t* row_ind, // precomputed row indices + index_t num_rows, // # rows of adj + index_t num_cols, // # cols of adj + index_t* row_counters, // pre-allocated (zeroed) atomic counters + index_t* out_col_ind // output column indices +) +{ + typedef raft::TxN_t bool16; + + for (index_t i = blockIdx.y; i < num_rows; i += gridDim.y) { + // Load row information + index_t row_base = row_ind[i]; + index_t* row_count = row_counters + i; + const bool* row = adj + i * num_cols; + + // Peeling: process the first j0 elements that are not aligned to a 16-byte + // boundary. + index_t j0 = (16 - (((uintptr_t)(const void*)row) % 16)) % 16; + j0 = min(j0, num_cols); + if (threadIdx.x < j0 && blockIdx.x == 0) { + if (row[threadIdx.x]) { out_col_ind[row_base + atomicIncWarp(row_count)] = threadIdx.x; } + } + + // Process the rest of the row in 16 byte chunks starting at j0. + // This is a grid-stride loop. + index_t j = j0 + 16 * (blockIdx.x * blockDim.x + threadIdx.x); + for (; j + 15 < num_cols; j += 16 * (blockDim.x * gridDim.x)) { + bool16 chunk; + chunk.load(row, j); + + for (int k = 0; k < 16; ++k) { + if (chunk.val.data[k]) { out_col_ind[row_base + atomicIncWarp(row_count)] = j + k; } + } + } + + // Remainder: process the last j1 bools in the row individually. + index_t j1 = (num_cols - j0) % 16; + if (threadIdx.x < j1 && blockIdx.x == 0) { + int j = num_cols - j1 + threadIdx.x; + if (row[j]) { out_col_ind[row_base + atomicIncWarp(row_count)] = j; } + } + } +} /** - * Takes vertex degree array (vd) and CSR row_ind array (ex_scan) to produce the - * CSR row_ind_ptr array (adj_graph) + * @brief Converts a boolean adjacency matrix into CSR format. + * + * @tparam[Index_]: indexing arithmetic type + * @param[in] handle: raft::handle_t + * + * @param[in,out] data: A struct containing the adjacency matrix, its number of + * columns, and the vertex degrees. + * + * @param[in] batch_size: The number of rows of the adjacency matrix data.adj + * @param row_counters: A pre-allocated temporary buffer on the device. + * Must be able to contain at least `batch_size` elements. + * @param[in] stream: CUDA stream */ template void launcher(const raft::handle_t& handle, Pack data, Index_ batch_size, + Index_* row_counters, cudaStream_t stream) { - using namespace thrust; + Index_ num_rows = batch_size; + Index_ num_cols = data.N; + bool* adj = data.adj; // batch_size x N row-major adjacency matrix + // Compute the exclusive scan of the vertex degrees + using namespace thrust; device_ptr dev_vd = device_pointer_cast(data.vd); device_ptr dev_ex_scan = device_pointer_cast(data.ex_scan); + thrust::exclusive_scan(handle.get_thrust_policy(), dev_vd, dev_vd + batch_size, dev_ex_scan); + + // Zero-fill a temporary vector that can be used by the adj_to_csr kernel to + // keep track of the number of entries added to a row. + RAFT_CUDA_TRY(cudaMemsetAsync(row_counters, 0, batch_size * sizeof(Index_), stream)); - exclusive_scan(handle.get_thrust_policy(), dev_vd, dev_vd + batch_size, dev_ex_scan); + // Split the grid in the row direction (since each row can be processed + // independently). If the maximum number of active blocks (num_sms * + // occupancy) exceeds the number of rows, assign multiple blocks to a single + // row. + int threads_per_block = 1024; + int dev_id, sm_count, blocks_per_sm; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, adj_to_csr, threads_per_block, 0); - raft::sparse::convert::csr_adj_graph_batched( - data.ex_scan, data.N, data.adjnnz, batch_size, data.adj, data.adj_graph, stream); + Index_ max_active_blocks = sm_count * blocks_per_sm; + Index_ blocks_per_row = raft::ceildiv(max_active_blocks, num_rows); + Index_ grid_rows = raft::ceildiv(max_active_blocks, blocks_per_row); + dim3 block(threads_per_block, 1); + dim3 grid(blocks_per_row, grid_rows); + adj_to_csr<<>>( + adj, data.ex_scan, num_rows, num_cols, row_counters, data.adj_graph); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/dbscan/adjgraph/runner.cuh b/cpp/src/dbscan/adjgraph/runner.cuh index 1a728a1a06..978f3ed14b 100644 --- a/cpp/src/dbscan/adjgraph/runner.cuh +++ b/cpp/src/dbscan/adjgraph/runner.cuh @@ -34,12 +34,14 @@ void run(const raft::handle_t& handle, Index_ N, int algo, Index_ batch_size, + Index_* row_counters, cudaStream_t stream) { Pack data = {vd, adj, adj_graph, adjnnz, ex_scan, N}; switch (algo) { + // TODO: deprecate naive runner. cf #3414 case 0: Naive::launcher(handle, data, batch_size, stream); break; - case 1: Algo::launcher(handle, data, batch_size, stream); break; + case 1: Algo::launcher(handle, data, batch_size, row_counters, stream); break; default: ASSERT(false, "Incorrect algo passed! '%d'", algo); } } diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index 4ebde16da5..ba14717b26 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -148,6 +148,7 @@ std::size_t run(const raft::handle_t& handle, std::size_t m_size = raft::alignTo(sizeof(bool), align); std::size_t vd_size = raft::alignTo(sizeof(Index_) * (batch_size + 1), align); std::size_t ex_scan_size = raft::alignTo(sizeof(Index_) * batch_size, align); + std::size_t row_cnt_size = raft::alignTo(sizeof(Index_) * batch_size, align); std::size_t labels_size = raft::alignTo(sizeof(Index_) * N, align); Index_ MAX_LABEL = std::numeric_limits::max(); @@ -160,7 +161,8 @@ std::size_t run(const raft::handle_t& handle, (unsigned long)batch_size); if (workspace == NULL) { - auto size = adj_size + core_pts_size + m_size + vd_size + ex_scan_size + 2 * labels_size; + auto size = + adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size + 2 * labels_size; return size; } @@ -179,6 +181,8 @@ std::size_t run(const raft::handle_t& handle, temp += vd_size; Index_* ex_scan = (Index_*)temp; temp += ex_scan_size; + Index_* row_counters = (Index_*)temp; + temp += row_cnt_size; Index_* labels_temp = (Index_*)temp; temp += labels_size; Index_* work_buffer = (Index_*)temp; @@ -237,8 +241,18 @@ std::size_t run(const raft::handle_t& handle, maxadjlen = curradjlen; adj_graph.resize(maxadjlen, stream); } - AdjGraph::run( - handle, adj, vd, adj_graph.data(), curradjlen, ex_scan, N, algo_adj, n_points, stream); + AdjGraph::run(handle, + adj, + vd, + adj_graph.data(), + curradjlen, + ex_scan, + N, + algo_adj, + n_points, + row_counters, + stream); + raft::common::nvtx::pop_range(); CUML_LOG_DEBUG("--> Computing connected components"); diff --git a/cpp/src/dbscan/vertexdeg/algo.cuh b/cpp/src/dbscan/vertexdeg/algo.cuh index 8cb13e8c4d..216e009a3e 100644 --- a/cpp/src/dbscan/vertexdeg/algo.cuh +++ b/cpp/src/dbscan/vertexdeg/algo.cuh @@ -16,15 +16,16 @@ #pragma once +#include "pack.h" #include #include +#include +#include #include #include #include #include -#include "pack.h" - namespace ML { namespace Dbscan { namespace VertexDeg { @@ -41,7 +42,11 @@ void launcher(const raft::handle_t& handle, cudaStream_t stream, raft::distance::DistanceType metric) { - data.resetArray(stream, batch_size + 1); + // The last position of data.vd is the sum of all elements in this array + // (excluding it). Hence, its length is one more than the number of points + // Initialize it to zero. + index_t* d_nnz = data.vd + batch_size; + RAFT_CUDA_TRY(cudaMemsetAsync(d_nnz, 0, sizeof(index_t), stream)); ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes"); @@ -50,6 +55,7 @@ void launcher(const raft::handle_t& handle, index_t k = data.D; value_t eps2; + // Compute adjacency matrix `adj` using Cosine or L2 metric. if (metric == raft::distance::DistanceType::CosineExpanded) { rmm::device_uvector rowNorms(m, stream); @@ -79,7 +85,7 @@ void launcher(const raft::handle_t& handle, eps2 = 2 * data.eps; raft::spatial::knn::epsUnexpL2SqNeighborhood( - data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream); + data.adj, nullptr, data.x + start_vertex_id * k, data.x, n, m, k, eps2, stream); /** * Restoring the input matrix after normalization. @@ -94,14 +100,32 @@ void launcher(const raft::handle_t& handle, true, [] __device__(value_t mat_in, value_t vec_in) { return mat_in * vec_in; }, stream); - } - - else { + } else { eps2 = data.eps * data.eps; + // 1. The output matrix adj is now an n x m matrix (row-major order) + // 2. Do not compute the vertex degree in epsUnexpL2SqNeighborhood (pass a + // nullptr) raft::spatial::knn::epsUnexpL2SqNeighborhood( - data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream); + data.adj, nullptr, data.x + start_vertex_id * k, data.x, n, m, k, eps2, stream); } + + // Reduction of adj to compute the vertex degrees + raft::linalg::coalescedReduction( + data.vd, + data.adj, + data.N, + batch_size, + (index_t)0, + stream, + false, + [] __device__(bool adj_ij, index_t idx) { return static_cast(adj_ij); }, + raft::Sum(), + [d_nnz] __device__(index_t degree) { + atomicAdd(d_nnz, degree); + return degree; + }); + RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace Algo diff --git a/cpp/src/dbscan/vertexdeg/precomputed.cuh b/cpp/src/dbscan/vertexdeg/precomputed.cuh index 3cead4bac8..3fa3a828bc 100644 --- a/cpp/src/dbscan/vertexdeg/precomputed.cuh +++ b/cpp/src/dbscan/vertexdeg/precomputed.cuh @@ -51,51 +51,44 @@ void launcher(const raft::handle_t& handle, index_t batch_size, cudaStream_t stream) { - const value_t& eps = data.eps; - - // Note: the matrix is symmetric. We take advantage of this to have two - // coalesced kernels: - // - The reduction works on a column-major N*B matrix to compute a B vector - // with a cub-BlockReduce-based primitive. - // The final_op is used to compute the total number of non-zero elements - // - The conversion to a boolean matrix works on a column-major B*N matrix - // (coalesced 2d copy + transform). - // - // If we end up supporting distributed distance matrices for MNMG, we can't - // rely on this trick anymore and need either a transposed kernel or to - // change the output layout. - // Regarding index types, a special index type is used here for indices in // the distance matrix due to its dimensions (that are independent of the // batch size) using long_index_t = long long int; - // Reduction to compute the vertex degrees + // The last position of data.vd is the sum of all elements in this array + // (excluding it). Hence, its length is one more than the number of points + // Initialize it to zero. index_t* d_nnz = data.vd + batch_size; RAFT_CUDA_TRY(cudaMemsetAsync(d_nnz, 0, sizeof(index_t), stream)); - raft::linalg::coalescedReduction( + + long_index_t N = data.N; + long_index_t cur_batch_size = min(data.N - start_vertex_id, batch_size); + + const value_t& eps = data.eps; + raft::linalg::unaryOp( + data.adj, + data.x + (long_index_t)start_vertex_id * N, + cur_batch_size * N, + [eps] __device__(value_t dist) { return (dist <= eps); }, + stream); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + // Reduction of adj to compute the vertex degrees + raft::linalg::coalescedReduction( data.vd, - data.x + start_vertex_id * data.N, + data.adj, data.N, batch_size, (index_t)0, stream, false, - [eps] __device__(value_t dist, long_index_t idx) { return static_cast(dist <= eps); }, + [] __device__(bool adj_ij, long_index_t idx) { return static_cast(adj_ij); }, raft::Sum(), [d_nnz] __device__(index_t degree) { atomicAdd(d_nnz, degree); return degree; }); - - // Transform the distance matrix into a neighborhood matrix - dist_to_adj_kernel<<>>( - data.x, - data.adj, - (long_index_t)data.N, - (long_index_t)start_vertex_id, - (long_index_t)batch_size, - data.eps); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh index 561c98ab12..ea1da4f4e0 100644 --- a/cpp/src/dbscan/vertexdeg/runner.cuh +++ b/cpp/src/dbscan/vertexdeg/runner.cuh @@ -41,6 +41,7 @@ void run(const raft::handle_t& handle, { Pack data = {vd, adj, x, eps, N, D}; switch (algo) { + // TODO: deprecate naive runner. cf #3414 case 0: Naive::launcher(data, start_vertex_id, batch_size, stream); break; case 1: Algo::launcher(handle, data, start_vertex_id, batch_size, stream, metric);