diff --git a/cpp/src/dbscan/adjgraph/algo.cuh b/cpp/src/dbscan/adjgraph/algo.cuh
index 8a2065f9a5..a1978b429a 100644
--- a/cpp/src/dbscan/adjgraph/algo.cuh
+++ b/cpp/src/dbscan/adjgraph/algo.cuh
@@ -16,6 +16,7 @@
 #pragma once
 
+#include <cstdint>
 #include <thrust/device_ptr.h>
 #include <thrust/scan.h>
 
@@ -23,35 +24,149 @@
 #include "pack.h"
 
 #include <raft/handle.hpp>
-#include <raft/sparse/convert/csr.cuh>
+#include <raft/cuda_utils.cuh>
+#include <raft/cudart_utils.h>
+#include <raft/device_atomics.cuh>
+#include <raft/vectorized.cuh>
 
 namespace ML {
 namespace Dbscan {
 namespace AdjGraph {
 namespace Algo {
 
-static const int TPB_X = 256;
+/**
+ * @brief Convert a boolean adjacency matrix into CSR format.
+ *
+ * The adj_to_csr kernel converts a boolean adjacency matrix into CSR format.
+ * High performance comes at the cost of non-deterministic output: the column
+ * indices are not guaranteed to be stored in order.
+ *
+ * The kernel has been optimized to handle matrices that are non-square, for
+ * instance subsets of a full adjacency matrix. In practice, these matrices can
+ * be very wide and not very tall. In principle, each row is assigned to one
+ * block. If there are more SMs than rows, multiple blocks operate on a single
+ * row. To enable cooperation between these blocks, each row is provided a
+ * counter where the current output index can be cooperatively (atomically)
+ * incremented. As a result, the column indices of a row are not guaranteed to
+ * be written in ascending order.
+ *
+ * @param[in] adj: a num_rows x num_cols boolean matrix in contiguous row-major
+ *                 format.
+ *
+ * @param[in] row_ind: an array of length num_rows that indicates the index at
+ *                     which each row starts in out_col_ind. Equivalently, it is
+ *                     the exclusive scan of the number of non-zeros in each row
+ *                     of `adj`.
+ *
+ * @param[in] num_rows: number of rows of adj.
+ * @param[in] num_cols: number of columns of adj.
+ *
+ * @param[in,out] row_counters: a temporary zero-initialized array of length num_rows.
+ *
+ * @param[out] out_col_ind: an array containing the column indices of the
+ *                          non-zero values in `adj`. Size should be at least
+ *                          the number of non-zeros in `adj`.
+ */
+template <typename index_t>
+__global__ void adj_to_csr(const bool* adj,         // row-major adjacency matrix
+                           const index_t* row_ind,  // precomputed row indices
+                           index_t num_rows,        // # rows of adj
+                           index_t num_cols,        // # cols of adj
+                           index_t* row_counters,   // pre-allocated (zeroed) atomic counters
+                           index_t* out_col_ind     // output column indices
+)
+{
+  typedef raft::TxN_t<bool, 16> bool16;
+
+  for (index_t i = blockIdx.y; i < num_rows; i += gridDim.y) {
+    // Load row information
+    index_t row_base   = row_ind[i];
+    index_t* row_count = row_counters + i;
+    const bool* row    = adj + i * num_cols;
+
+    // Peeling: process the first j0 elements, which are not aligned to a
+    // 16-byte boundary.
+    index_t j0 = (16 - (((uintptr_t)(const void*)row) % 16)) % 16;
+    j0         = min(j0, num_cols);
+    if (threadIdx.x < j0 && blockIdx.x == 0) {
+      if (row[threadIdx.x]) { out_col_ind[row_base + atomicIncWarp(row_count)] = threadIdx.x; }
+    }
+
+    // Process the rest of the row in 16-byte chunks starting at j0.
+    // This is a grid-stride loop.
+    index_t j = j0 + 16 * (blockIdx.x * blockDim.x + threadIdx.x);
+    for (; j + 15 < num_cols; j += 16 * (blockDim.x * gridDim.x)) {
+      bool16 chunk;
+      chunk.load(row, j);
+
+      for (int k = 0; k < 16; ++k) {
+        if (chunk.val.data[k]) { out_col_ind[row_base + atomicIncWarp(row_count)] = j + k; }
+      }
+    }
+
+    // Remainder: process the last j1 bools in the row individually.
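+    // For example (hypothetical values): with num_cols = 37 and j0 = 3, the
+    // vectorized loop above covers j = 3..18 and j = 19..34, leaving
+    // j1 = (37 - 3) % 16 = 2 trailing elements (columns 35 and 36) for the
+    // first j1 threads of block 0 to handle here.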
+    index_t j1 = (num_cols - j0) % 16;
+    if (threadIdx.x < j1 && blockIdx.x == 0) {
+      index_t j = num_cols - j1 + threadIdx.x;
+      if (row[j]) { out_col_ind[row_base + atomicIncWarp(row_count)] = j; }
+    }
+  }
+}
 
 /**
- * Takes vertex degree array (vd) and CSR row_ind array (ex_scan) to produce the
- * CSR row_ind_ptr array (adj_graph)
+ * @brief Converts a boolean adjacency matrix into CSR format.
+ *
+ * @tparam Index_ indexing arithmetic type
+ * @param[in] handle: raft::handle_t
+ *
+ * @param[in,out] data: A struct containing the adjacency matrix, its number of
+ *                      columns, and the vertex degrees.
+ *
+ * @param[in] batch_size: The number of rows of the adjacency matrix data.adj
+ * @param[in,out] row_counters: A pre-allocated temporary buffer on the device.
+ *                              Must be able to contain at least `batch_size`
+ *                              elements.
+ * @param[in] stream: CUDA stream
 */
 template <typename Index_>
 void launcher(const raft::handle_t& handle,
               Pack<Index_> data,
               Index_ batch_size,
+              Index_* row_counters,
               cudaStream_t stream)
 {
-  using namespace thrust;
+  Index_ num_rows = batch_size;
+  Index_ num_cols = data.N;
+  bool* adj       = data.adj;  // batch_size x N row-major adjacency matrix
 
+  // Compute the exclusive scan of the vertex degrees
+  using namespace thrust;
   device_ptr<Index_> dev_vd      = device_pointer_cast(data.vd);
   device_ptr<Index_> dev_ex_scan = device_pointer_cast(data.ex_scan);
+  thrust::exclusive_scan(handle.get_thrust_policy(), dev_vd, dev_vd + batch_size, dev_ex_scan);
+
+  // Zero-fill a temporary vector that can be used by the adj_to_csr kernel to
+  // keep track of the number of entries added to a row.
+  RAFT_CUDA_TRY(cudaMemsetAsync(row_counters, 0, batch_size * sizeof(Index_), stream));
 
-  exclusive_scan(handle.get_thrust_policy(), dev_vd, dev_vd + batch_size, dev_ex_scan);
+  // Split the grid in the row direction (since each row can be processed
+  // independently). If the maximum number of active blocks (num_sms *
+  // occupancy) exceeds the number of rows, assign multiple blocks to a single
+  // row.
+  int threads_per_block = 1024;
+  int dev_id, sm_count, blocks_per_sm;
+  cudaGetDevice(&dev_id);
+  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
+  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+    &blocks_per_sm, adj_to_csr<Index_>, threads_per_block, 0);
 
-  raft::sparse::convert::csr_adj_graph_batched<Index_, TPB_X>(
-    data.ex_scan, data.N, data.adjnnz, batch_size, data.adj, data.adj_graph, stream);
+  Index_ max_active_blocks = sm_count * blocks_per_sm;
+  Index_ blocks_per_row    = raft::ceildiv(max_active_blocks, num_rows);
+  Index_ grid_rows         = raft::ceildiv(max_active_blocks, blocks_per_row);
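+  // For example (hypothetical values): on a device with sm_count = 108 and
+  // blocks_per_sm = 2, max_active_blocks = 216. With num_rows = 100, this
+  // yields blocks_per_row = ceildiv(216, 100) = 3 and
+  // grid_rows = ceildiv(216, 3) = 72; the grid-stride loop over blockIdx.y in
+  // adj_to_csr covers the remaining rows.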
+  dim3 block(threads_per_block, 1);
+  dim3 grid(blocks_per_row, grid_rows);
+  adj_to_csr<Index_><<<grid, block, 0, stream>>>(
+    adj, data.ex_scan, num_rows, num_cols, row_counters, data.adj_graph);
 
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
diff --git a/cpp/src/dbscan/adjgraph/runner.cuh b/cpp/src/dbscan/adjgraph/runner.cuh
index 1a728a1a06..978f3ed14b 100644
--- a/cpp/src/dbscan/adjgraph/runner.cuh
+++ b/cpp/src/dbscan/adjgraph/runner.cuh
@@ -34,12 +34,14 @@ void run(const raft::handle_t& handle,
          Index_ N,
          int algo,
          Index_ batch_size,
+         Index_* row_counters,
          cudaStream_t stream)
 {
   Pack<Index_> data = {vd, adj, adj_graph, adjnnz, ex_scan, N};
   switch (algo) {
+    // TODO: deprecate naive runner. cf #3414
     case 0: Naive::launcher<Index_>(handle, data, batch_size, stream); break;
-    case 1: Algo::launcher<Index_>(handle, data, batch_size, stream); break;
+    case 1: Algo::launcher<Index_>(handle, data, batch_size, row_counters, stream); break;
     default: ASSERT(false, "Incorrect algo passed! '%d'", algo);
   }
 }
diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh
index 4ebde16da5..ba14717b26 100644
--- a/cpp/src/dbscan/runner.cuh
+++ b/cpp/src/dbscan/runner.cuh
@@ -148,6 +148,7 @@ std::size_t run(const raft::handle_t& handle,
   std::size_t m_size       = raft::alignTo<std::size_t>(sizeof(bool), align);
   std::size_t vd_size      = raft::alignTo<std::size_t>(sizeof(Index_) * (batch_size + 1), align);
   std::size_t ex_scan_size = raft::alignTo<std::size_t>(sizeof(Index_) * batch_size, align);
+  std::size_t row_cnt_size = raft::alignTo<std::size_t>(sizeof(Index_) * batch_size, align);
   std::size_t labels_size  = raft::alignTo<std::size_t>(sizeof(Index_) * N, align);
 
   Index_ MAX_LABEL = std::numeric_limits<Index_>::max();
@@ -160,7 +161,8 @@
                  (unsigned long)batch_size);
 
   if (workspace == NULL) {
-    auto size = adj_size + core_pts_size + m_size + vd_size + ex_scan_size + 2 * labels_size;
+    auto size =
+      adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size + 2 * labels_size;
     return size;
   }
@@ -179,6 +181,8 @@
   temp += vd_size;
   Index_* ex_scan = (Index_*)temp;
   temp += ex_scan_size;
+  Index_* row_counters = (Index_*)temp;
+  temp += row_cnt_size;
   Index_* labels_temp = (Index_*)temp;
   temp += labels_size;
   Index_* work_buffer = (Index_*)temp;
@@ -237,8 +241,18 @@
       maxadjlen = curradjlen;
       adj_graph.resize(maxadjlen, stream);
     }
-    AdjGraph::run<Index_>(
-      handle, adj, vd, adj_graph.data(), curradjlen, ex_scan, N, algo_adj, n_points, stream);
+    AdjGraph::run<Index_>(handle,
+                          adj,
+                          vd,
+                          adj_graph.data(),
+                          curradjlen,
+                          ex_scan,
+                          N,
+                          algo_adj,
+                          n_points,
+                          row_counters,
+                          stream);
+
     raft::common::nvtx::pop_range();
 
     CUML_LOG_DEBUG("--> Computing connected components");
diff --git a/cpp/src/dbscan/vertexdeg/algo.cuh b/cpp/src/dbscan/vertexdeg/algo.cuh
index 8cb13e8c4d..216e009a3e 100644
--- a/cpp/src/dbscan/vertexdeg/algo.cuh
+++ b/cpp/src/dbscan/vertexdeg/algo.cuh
@@ -16,15 +16,16 @@
 #pragma once
 
+#include "pack.h"
 #include <math.h>
 #include <raft/cuda_utils.cuh>
+#include <raft/device_atomics.cuh>
+#include <raft/linalg/coalesced_reduction.cuh>
 #include <raft/linalg/matrix_vector_op.cuh>
 #include <raft/linalg/norm.cuh>
 #include <raft/spatial/knn/epsilon_neighborhood.cuh>
 #include <rmm/device_uvector.hpp>
 
-#include "pack.h"
-
 namespace ML {
 namespace Dbscan {
 namespace VertexDeg {
@@ -41,7 +42,11 @@ void launcher(const raft::handle_t& handle,
               cudaStream_t stream,
               raft::distance::DistanceType metric)
 {
-  data.resetArray(stream, batch_size + 1);
+  // The last position of data.vd is the sum of all elements in this array
+  // (excluding it). Hence, its length is one more than the number of points.
+  // Initialize it to zero.
+  index_t* d_nnz = data.vd + batch_size;
+  RAFT_CUDA_TRY(cudaMemsetAsync(d_nnz, 0, sizeof(index_t), stream));
 
   ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes");
 
@@ -50,6 +55,7 @@
   index_t k = data.D;
   value_t eps2;
 
+  // Compute adjacency matrix `adj` using Cosine or L2 metric.
   if (metric == raft::distance::DistanceType::CosineExpanded) {
     rmm::device_uvector<value_t> rowNorms(m, stream);
 
@@ -79,7 +85,7 @@
     eps2 = 2 * data.eps;
 
     raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
-      data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream);
+      data.adj, nullptr, data.x + start_vertex_id * k, data.x, n, m, k, eps2, stream);
 
     /**
      * Restoring the input matrix after normalization.
@@ -94,14 +100,32 @@
       true,
       [] __device__(value_t mat_in, value_t vec_in) { return mat_in * vec_in; },
       stream);
-  }
-
-  else {
+  } else {
     eps2 = data.eps * data.eps;
 
+    // 1. The output matrix adj is now an n x m matrix (row-major order)
+    // 2. Do not compute the vertex degree in epsUnexpL2SqNeighborhood (pass a
+    //    nullptr)
     raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
-      data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream);
+      data.adj, nullptr, data.x + start_vertex_id * k, data.x, n, m, k, eps2, stream);
   }
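+
+  // For example (hypothetical values): with batch_size = 2 and data.N = 4,
+  // adj could hold the rows {1,0,1,0} and {1,1,0,0}; the reduction below then
+  // yields vd = {2, 2} and accumulates 4 into *d_nnz.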
+  // Reduction of adj to compute the vertex degrees
+  raft::linalg::coalescedReduction<bool, index_t, index_t>(
+    data.vd,
+    data.adj,
+    data.N,
+    batch_size,
+    (index_t)0,
+    stream,
+    false,
+    [] __device__(bool adj_ij, index_t idx) { return static_cast<index_t>(adj_ij); },
+    raft::Sum<index_t>(),
+    [d_nnz] __device__(index_t degree) {
+      atomicAdd(d_nnz, degree);
+      return degree;
+    });
+
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
 
 }  // namespace Algo
diff --git a/cpp/src/dbscan/vertexdeg/precomputed.cuh b/cpp/src/dbscan/vertexdeg/precomputed.cuh
index 3cead4bac8..3fa3a828bc 100644
--- a/cpp/src/dbscan/vertexdeg/precomputed.cuh
+++ b/cpp/src/dbscan/vertexdeg/precomputed.cuh
@@ -51,51 +51,44 @@ void launcher(const raft::handle_t& handle,
   index_t batch_size,
   cudaStream_t stream)
 {
-  const value_t& eps = data.eps;
-
-  // Note: the matrix is symmetric. We take advantage of this to have two
-  // coalesced kernels:
-  // - The reduction works on a column-major N*B matrix to compute a B vector
-  //   with a cub-BlockReduce-based primitive.
-  //   The final_op is used to compute the total number of non-zero elements
-  // - The conversion to a boolean matrix works on a column-major B*N matrix
-  //   (coalesced 2d copy + transform).
-  //
-  // If we end up supporting distributed distance matrices for MNMG, we can't
-  // rely on this trick anymore and need either a transposed kernel or to
-  // change the output layout.
-
   // Regarding index types, a special index type is used here for indices in
   // the distance matrix due to its dimensions (that are independent of the
   // batch size)
   using long_index_t = long long int;
 
-  // Reduction to compute the vertex degrees
+  // The last position of data.vd is the sum of all elements in this array
+  // (excluding it). Hence, its length is one more than the number of points.
+  // Initialize it to zero.
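+  // Compute the adjacency matrix by thresholding the precomputed distances.
+  // For example (hypothetical values): with eps = 0.5, a distance row
+  // {0.1, 0.7, 0.4} maps to the adjacency row {1, 0, 1}.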
   index_t* d_nnz = data.vd + batch_size;
   RAFT_CUDA_TRY(cudaMemsetAsync(d_nnz, 0, sizeof(index_t), stream));
-  raft::linalg::coalescedReduction<value_t, index_t, long_index_t>(
+
+  long_index_t N              = data.N;
+  long_index_t cur_batch_size = min(data.N - start_vertex_id, batch_size);
+
+  const value_t& eps = data.eps;
+  raft::linalg::unaryOp(
+    data.adj,
+    data.x + (long_index_t)start_vertex_id * N,
+    cur_batch_size * N,
+    [eps] __device__(value_t dist) { return (dist <= eps); },
+    stream);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  // Reduction of adj to compute the vertex degrees
+  raft::linalg::coalescedReduction<bool, index_t, long_index_t>(
     data.vd,
-    data.x + start_vertex_id * data.N,
+    data.adj,
     data.N,
     batch_size,
     (index_t)0,
     stream,
     false,
-    [eps] __device__(value_t dist, long_index_t idx) { return static_cast<index_t>(dist <= eps); },
+    [] __device__(bool adj_ij, long_index_t idx) { return static_cast<index_t>(adj_ij); },
     raft::Sum<index_t>(),
     [d_nnz] __device__(index_t degree) {
       atomicAdd(d_nnz, degree);
       return degree;
     });
-
-  // Transform the distance matrix into a neighborhood matrix
-  dist_to_adj_kernel<<<...>>>(
-    data.x,
-    data.adj,
-    (long_index_t)data.N,
-    (long_index_t)start_vertex_id,
-    (long_index_t)batch_size,
-    data.eps);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh
index 561c98ab12..ea1da4f4e0 100644
--- a/cpp/src/dbscan/vertexdeg/runner.cuh
+++ b/cpp/src/dbscan/vertexdeg/runner.cuh
@@ -41,6 +41,7 @@ void run(const raft::handle_t& handle,
 {
   Pack<T, Index_> data = {vd, adj, x, eps, N, D};
   switch (algo) {
+    // TODO: deprecate naive runner. cf #3414
     case 0: Naive::launcher<T, Index_>(data, start_vertex_id, batch_size, stream); break;
     case 1:
       Algo::launcher<T, Index_>(handle, data, start_vertex_id, batch_size, stream, metric);