From caa44e6b0a5f34b35d72935e1bd5b4266a270368 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Wed, 23 Jun 2021 10:28:43 -0700 Subject: [PATCH] Sparse semirings cleanup + hash table & batching strategies (#269) This branch includes several new features and optimizations: 1. Introduces a hash table strategy to sparsify the vector in the coo spmv shared memory 2. Adds a batching strategy for rows with nnz too large to fit into shared memory 3. Removes the need for the cusparse csrgemm 4. Uses raft handle in distances_config_t rather than accepting each resource explicitly 5. Removes the naive CSR semiring code This PR is also required to merge #261, which introduces the remaining distances Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/269 --- cpp/cmake/thirdparty/get_cuco.cmake | 2 +- cpp/include/raft/linalg/distance_type.h | 10 +- .../raft/sparse/distance/bin_distance.cuh | 21 +- cpp/include/raft/sparse/distance/common.h | 10 +- cpp/include/raft/sparse/distance/coo_spmv.cuh | 300 ++------- .../coo_spmv_strategies/base_strategy.cuh | 91 +++ .../coo_mask_row_iterators.cuh | 192 ++++++ .../dense_smem_strategy.cuh | 97 +++ .../coo_spmv_strategies/hash_strategy.cuh | 222 +++++++ cpp/include/raft/sparse/distance/csr_spmv.cuh | 486 -------------- .../distance/detail/coo_spmv_kernel.cuh | 208 ++++++ .../raft/sparse/distance/ip_distance.cuh | 252 +------- .../raft/sparse/distance/l2_distance.cuh | 82 ++- .../raft/sparse/distance/lp_distance.cuh | 57 +- .../raft/sparse/distance/operators.cuh | 7 + cpp/include/raft/sparse/distance/utils.cuh | 46 ++ cpp/include/raft/sparse/selection/knn.cuh | 105 ++- cpp/test/CMakeLists.txt | 2 +- cpp/test/sparse/dist_coo_spmv.cu | 297 +++++---- cpp/test/sparse/dist_csr_spmv.cu | 612 ------------------ cpp/test/sparse/distance.cu | 67 +- cpp/test/sparse/knn.cu | 79 +-- 22 files changed, 1309 insertions(+), 1936 deletions(-) create mode 100644 cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh create mode 100644 cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh create mode 100644 cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh create mode 100644 cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh delete mode 100644 cpp/include/raft/sparse/distance/csr_spmv.cuh create mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh create mode 100644 cpp/include/raft/sparse/distance/utils.cuh delete mode 100644 cpp/test/sparse/dist_csr_spmv.cu diff --git a/cpp/cmake/thirdparty/get_cuco.cmake b/cpp/cmake/thirdparty/get_cuco.cmake index b9542a42f2..94ce6abf7d 100644 --- a/cpp/cmake/thirdparty/get_cuco.cmake +++ b/cpp/cmake/thirdparty/get_cuco.cmake @@ -20,7 +20,7 @@ function(find_and_configure_cuco VERSION) GLOBAL_TARGETS cuco cuco::cuco CPM_ARGS GIT_REPOSITORY https://github.com/NVIDIA/cuCollections.git - GIT_TAG 0b672bbde7c85a79df4d7ca5f82e15e5b4a57700 + GIT_TAG e5e2abe55152608ef449ecf162a1ef52ded19801 OPTIONS "BUILD_TESTS OFF" "BUILD_BENCHMARKS OFF" "BUILD_EXAMPLES OFF" diff --git a/cpp/include/raft/linalg/distance_type.h b/cpp/include/raft/linalg/distance_type.h index f3a22a07ed..681a83f3f8 100644 --- a/cpp/include/raft/linalg/distance_type.h +++ b/cpp/include/raft/linalg/distance_type.h @@ -52,10 +52,16 @@ enum DistanceType : unsigned short { Haversine = 13, /** Bray-Curtis distance **/ BrayCurtis = 14, - /** Jensen-Shannon distance **/ + /** Jensen-Shannon distance**/ JensenShannon = 15, + /** Hamming distance **/ + HammingUnexpanded = 16, + /** KLDivergence **/ + KLDivergence = 17, + /** RusselRao **/ + RusselRaoExpanded = 18, /** Dice-Sorensen distance **/ - DiceExpanded = 16, + DiceExpanded = 19, /** Precomputed (special value) **/ Precomputed = 100 }; diff --git a/cpp/include/raft/sparse/distance/bin_distance.cuh b/cpp/include/raft/sparse/distance/bin_distance.cuh index ae5cbdf9d3..e5ac85e909 100644 --- a/cpp/include/raft/sparse/distance/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/bin_distance.cuh @@ -85,7 +85,6 @@ void compute_bin_distance(value_t *out, const value_idx *Q_coo_rows, const value_t *Q_data, value_idx Q_nnz, const value_idx *R_coo_rows, const value_t *R_data, value_idx R_nnz, value_idx m, value_idx n, - cusparseHandle_t handle, std::shared_ptr alloc, cudaStream_t stream, expansion_f expansion_func) { raft::mr::device::buffer Q_norms(alloc, stream, m); @@ -114,7 +113,8 @@ class jaccard_expanded_distances_t : public distances_t { explicit jaccard_expanded_distances_t( const distances_config_t &config) : config_(&config), - workspace(config.allocator, config.stream, 0), + workspace(config.handle.get_device_allocator(), + config.handle.get_stream(), 0), ip_dists(config) {} void compute(value_t *out_dists) { @@ -124,15 +124,16 @@ class jaccard_expanded_distances_t : public distances_t { value_t *b_data = ip_dists.b_data_coo(); raft::mr::device::buffer search_coo_rows( - config_->allocator, config_->stream, config_->a_nnz); + config_->handle.get_device_allocator(), config_->handle.get_stream(), + config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, - config_->stream); + config_->handle.get_stream()); compute_bin_distance( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, - config_->handle, config_->allocator, config_->stream, + config_->handle.get_device_allocator(), config_->handle.get_stream(), [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { value_t q_r_union = q_norm + r_norm; value_t denom = q_r_union - dot; @@ -163,7 +164,8 @@ class dice_expanded_distances_t : public distances_t { explicit dice_expanded_distances_t( const distances_config_t &config) : config_(&config), - workspace(config.allocator, config.stream, 0), + workspace(config.handle.get_device_allocator(), + config.handle.get_stream(), 0), ip_dists(config) {} void compute(value_t *out_dists) { @@ -173,15 +175,16 @@ class dice_expanded_distances_t : public distances_t { value_t *b_data = ip_dists.b_data_coo(); raft::mr::device::buffer search_coo_rows( - config_->allocator, config_->stream, config_->a_nnz); + config_->handle.get_device_allocator(), config_->handle.get_stream(), + config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, - config_->stream); + config_->handle.get_stream()); compute_bin_distance( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, - config_->handle, config_->allocator, config_->stream, + config_->handle.get_device_allocator(), config_->handle.get_stream(), [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { value_t q_r_union = q_norm + r_norm; value_t dice = (2 * dot) / q_r_union; diff --git a/cpp/include/raft/sparse/distance/common.h b/cpp/include/raft/sparse/distance/common.h index 712d2c52bd..36982c36ff 100644 --- a/cpp/include/raft/sparse/distance/common.h +++ b/cpp/include/raft/sparse/distance/common.h @@ -16,8 +16,7 @@ #pragma once -#include -#include +#include namespace raft { namespace sparse { @@ -25,6 +24,8 @@ namespace distance { template struct distances_config_t { + distances_config_t(raft::handle_t &handle_) : handle(handle_) {} + // left side value_idx a_nrows; value_idx a_ncols; @@ -41,10 +42,7 @@ struct distances_config_t { value_idx *b_indices; value_t *b_data; - cusparseHandle_t handle; - - std::shared_ptr allocator; - cudaStream_t stream; + raft::handle_t &handle; }; template diff --git a/cpp/include/raft/sparse/distance/coo_spmv.cuh b/cpp/include/raft/sparse/distance/coo_spmv.cuh index d596c6b852..3a78f9ada0 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/coo_spmv.cuh @@ -16,15 +16,18 @@ #pragma once +#include "coo_spmv_strategies/dense_smem_strategy.cuh" +#include "coo_spmv_strategies/hash_strategy.cuh" + #include #include #include #include #include -#include -#include -#include +#include "../csr.cuh" +#include "../utils.h" +#include "common.h" #include @@ -32,189 +35,26 @@ #include -#include +#include namespace raft { namespace sparse { namespace distance { -/** - * Load-balanced sparse-matrix-sparse-matrix multiplication (SPMM) kernel with - * sparse-matrix-sparse-vector multiplication layout (SPMV). - * This is intended to be scheduled n_chunks_b times for each row of a. - * The steps are as follows: - * - * 1. Load row from A into dense vector in shared memory. - * This can be further chunked in the future if necessary to support larger - * column sizes. - * 2. Threads of block all step through chunks of B in parallel. - * When a new row is encountered in row_indices_b, a segmented - * reduction is performed across the warps and then across the - * block and the final value written out to host memory. - * - * Reference: https://www.icl.utk.edu/files/publications/2020/icl-utk-1421-2020.pdf - * - * @tparam value_idx index type - * @tparam value_t value type - * @tparam tpb threads per block configured on launch - * @tparam rev if this is true, the reduce/accumulate functions are only - * executed when A[col] == 0.0. when executed before/after !rev - * and A & B are reversed, this allows the full symmetric difference - * and intersection to be computed. - * @tparam kv_t data type stored in shared mem cache - * @tparam product_f reduce function type (semiring product() function). - * accepts two arguments of value_t and returns a value_t - * @tparam accum_f accumulation function type (semiring sum() function). - * accepts two arguments of value_t and returns a value_t - * @tparam write_f function to write value out. this should be mathematically - * equivalent to the accumulate function but implemented as - * an atomic operation on global memory. Accepts two arguments - * of value_t* and value_t and updates the value given by the - * pointer. - * @param[in] indptrA column pointer array for A - * @param[in] indicesA column indices array for A - * @param[in] dataA data array for A - * @param[in] rowsB coo row array for B - * @param[in] indicesB column indices array for B - * @param[in] dataB data array for B - * @param[in] m number of rows in A - * @param[in] n number of rows in B - * @param[in] dim number of features - * @param[in] nnz_b number of nonzeros in B - * @param[out] out array of size m*n - * @param[in] n_blocks_per_row number of blocks of B per row of A - * @param[in] chunk_size number of nnz for B to use for each row of A - * @param[in] buffer_size amount of smem to use for each row of A - * @param[in] product_func semiring product() function - * @param[in] accum_func semiring sum() function - * @param[in] write_func atomic semiring sum() function - */ -template -__global__ void balanced_coo_generalized_spmv_kernel( - value_idx *indptrA, value_idx *indicesA, value_t *dataA, value_idx *rowsB, - value_idx *indicesB, value_t *dataB, value_idx m, value_idx n, value_idx dim, - value_idx nnz_b, value_t *out, int n_blocks_per_row, int chunk_size, - product_f product_func, accum_f accum_func, write_f write_func) { - typedef cub::WarpReduce warp_reduce; - - value_idx cur_row_a = blockIdx.x / n_blocks_per_row; - value_idx cur_chunk_offset = blockIdx.x % n_blocks_per_row; - - // chunk starting offset - value_idx ind_offset = cur_chunk_offset * chunk_size * tpb; - // how many total cols will be processed by this block (should be <= chunk_size * n_threads) - value_idx active_chunk_size = min(chunk_size * tpb, nnz_b - ind_offset); - - int tid = threadIdx.x; - int warp_id = tid / raft::warp_size(); - - // compute id relative to current warp - unsigned int lane_id = tid & (raft::warp_size() - 1); - value_idx ind = ind_offset + threadIdx.x; - - extern __shared__ char smem[]; - - value_idx *offsets_a = (value_idx *)smem; - kv_t *A = (kv_t *)(offsets_a + 2); - typename warp_reduce::TempStorage *temp_storage = - (typename warp_reduce::TempStorage *)(A + dim); - - // Create dense vector A and populate with 0s - for (int k = tid; k < dim; k += blockDim.x) A[k] = 0; - - if (tid == 0) { - offsets_a[0] = indptrA[cur_row_a]; - offsets_a[1] = indptrA[cur_row_a + 1]; - } - - __syncthreads(); - - value_idx start_offset_a = offsets_a[0]; - value_idx stop_offset_a = offsets_a[1]; - - // Convert current row vector in A to dense - for (int i = tid; i < (stop_offset_a - start_offset_a); i += blockDim.x) { - A[indicesA[start_offset_a + i]] = dataA[start_offset_a + i]; - } - - __syncthreads(); - - if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return; - if (ind >= nnz_b) return; - - value_idx cur_row_b = -1; - value_t c = 0.0; - - auto warp_red = warp_reduce(*(temp_storage + warp_id)); - - // coalesced reads from B - if (tid < active_chunk_size) { - cur_row_b = rowsB[ind]; - value_t a_col = A[indicesB[ind]]; - if (!rev || a_col == 0.0) c = product_func(a_col, dataB[ind]); - } - - // loop through chunks in parallel, reducing when a new row is - // encountered by each thread - for (int i = tid; i < active_chunk_size; i += blockDim.x) { - value_idx ind_next = ind + blockDim.x; - value_idx next_row_b = -1; - - if (i + blockDim.x < active_chunk_size) next_row_b = rowsB[ind_next]; - - bool diff_rows = next_row_b != cur_row_b; - - if (__any_sync(0xffffffff, diff_rows)) { - // grab the threads currently participating in loops. - // because any other threads should have returned already. - unsigned int peer_group = __match_any_sync(0xffffffff, cur_row_b); - bool is_leader = get_lowest_peer(peer_group) == lane_id; - value_t v = warp_red.HeadSegmentedReduce(c, is_leader, accum_func); - - // thread with lowest lane id among peers writes out - if (is_leader && v != 0.0) { - // this conditional should be uniform, since rev is constant - size_t idx = !rev ? (size_t)cur_row_a * n + cur_row_b - : (size_t)cur_row_b * m + cur_row_a; - write_func(out + idx, v); - } - - c = 0.0; - } - - if (next_row_b != -1) { - ind = ind_next; - value_t a_col = A[indicesB[ind]]; - if (!rev || a_col == 0.0) - c = accum_func(c, product_func(a_col, dataB[ind])); - cur_row_b = next_row_b; - } - } -} - -/** - * Computes the maximum number of columns that can be stored - * in shared memory in dense form with the given block size - * and precision. - * @return the maximum number of columns that can be stored in smem - */ -template -inline int max_cols_per_block() { - // max cols = (total smem available - offsets for A - cub reduction smem) - return (raft::getSharedMemPerBlock() - (2 * sizeof(value_idx)) - - ((tpb / raft::warp_size()) * sizeof(value_t))) / - sizeof(value_t); -} +template +inline void balanced_coo_pairwise_generalized_spmv( + value_t *out_dists, const distances_config_t &config_, + value_idx *coo_rows_b, product_f product_func, accum_f accum_func, + write_f write_func, strategy_t strategy, int chunk_size = 500000) { + CUDA_CHECK(cudaMemsetAsync( + out_dists, 0, sizeof(value_t) * config_.a_nrows * config_.b_nrows, + config_.handle.get_stream())); -template -inline int smem_per_block(int n_cols) { - int max_cols = max_cols_per_block(); - ASSERT(n_cols <= max_cols, "COO SPMV Requires max dimensionality of %d", - max_cols); - return (n_cols * sizeof(value_t)) + (2 * sizeof(value_idx)) + - ((tpb / raft::warp_size()) * sizeof(value_t)); -} + strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, + chunk_size); +}; /** * Performs generalized sparse-matrix-sparse-matrix multiplication via a @@ -235,9 +75,6 @@ inline int smem_per_block(int n_cols) { * @tparam value_idx index type * @tparam value_t value type * @tparam threads_per_block block size - * @tparam chunk_size number of nonzeros of B to process for each row of A - * this value was found through profiling and represents a reasonable - * setting for both large and small densities * @tparam product_f semiring product() function * @tparam accum_f semiring sum() function * @tparam write_f atomic semiring sum() function @@ -248,37 +85,43 @@ inline int smem_per_block(int n_cols) { * @param[in] product_func semiring product() function * @param[in] accum_func semiring sum() function * @param[in] write_func atomic semiring sum() function + * @param[in] chunk_size number of nonzeros of B to process for each row of A + * this value was found through profiling and represents a reasonable + * setting for both large and small densities */ template + typename product_f, typename accum_f, typename write_f> inline void balanced_coo_pairwise_generalized_spmv( value_t *out_dists, const distances_config_t &config_, value_idx *coo_rows_b, product_f product_func, accum_f accum_func, - write_f write_func) { + write_f write_func, int chunk_size = 500000) { CUDA_CHECK(cudaMemsetAsync( out_dists, 0, sizeof(value_t) * config_.a_nrows * config_.b_nrows, - config_.stream)); - int n_blocks_per_row = - raft::ceildiv(config_.b_nnz, chunk_size * threads_per_block); - int n_blocks = config_.a_nrows * n_blocks_per_row; - - int smem = - smem_per_block(config_.a_ncols); - - CUDA_CHECK(cudaFuncSetCacheConfig( - balanced_coo_generalized_spmv_kernel, - cudaFuncCachePreferShared)); + config_.handle.get_stream())); + + int max_cols = max_cols_per_block(); + + if (max_cols > config_.a_ncols) { + dense_smem_strategy strategy( + config_); + strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, + write_func, chunk_size); + } else { + hash_strategy strategy(config_); + strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, + write_func, chunk_size); + } +}; - balanced_coo_generalized_spmv_kernel - <<>>( - config_.a_indptr, config_.a_indices, config_.a_data, coo_rows_b, - config_.b_indices, config_.b_data, config_.a_nrows, config_.b_nrows, - config_.b_ncols, config_.b_nnz, out_dists, n_blocks_per_row, chunk_size, - product_func, accum_func, write_func); +template +inline void balanced_coo_pairwise_generalized_spmv_rev( + value_t *out_dists, const distances_config_t &config_, + value_idx *coo_rows_a, product_f product_func, accum_f accum_func, + write_f write_func, strategy_t strategy, int chunk_size = 500000) { + strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, + write_func, chunk_size); }; /** @@ -304,9 +147,6 @@ inline void balanced_coo_pairwise_generalized_spmv( * @tparam value_idx index type * @tparam value_t value type * @tparam threads_per_block block size - * @tparam chunk_size number of nonzeros of B to process for each row of A - * this value was found through profiling and represents a reasonable - * setting for both large and small densities * @tparam product_f semiring product() function * @tparam accum_f semiring sum() function * @tparam write_f atomic semiring sum() function @@ -316,35 +156,31 @@ inline void balanced_coo_pairwise_generalized_spmv( * @param[in] product_func semiring product() function * @param[in] accum_func semiring sum() function * @param[in] write_func atomic semiring sum() function + * @param[in] chunk_size number of nonzeros of B to process for each row of A + * this value was found through profiling and represents a reasonable + * setting for both large and small densities */ template + typename product_f, typename accum_f, typename write_f> inline void balanced_coo_pairwise_generalized_spmv_rev( value_t *out_dists, const distances_config_t &config_, value_idx *coo_rows_a, product_f product_func, accum_f accum_func, - write_f write_func) { - int n_blocks_per_row = - raft::ceildiv(config_.a_nnz, chunk_size * threads_per_block); - int n_blocks = config_.b_nrows * n_blocks_per_row; - - int smem = - smem_per_block(config_.a_ncols); - - CUDA_CHECK(cudaFuncSetCacheConfig( - balanced_coo_generalized_spmv_kernel, - cudaFuncCachePreferShared)); - - balanced_coo_generalized_spmv_kernel - <<>>( - config_.b_indptr, config_.b_indices, config_.b_data, coo_rows_a, - config_.a_indices, config_.a_data, config_.b_nrows, config_.a_nrows, - config_.a_ncols, config_.a_nnz, out_dists, n_blocks_per_row, chunk_size, - product_func, accum_func, write_func); + write_f write_func, int chunk_size = 500000) { + // try dense first + int max_cols = max_cols_per_block(); + + if (max_cols > config_.b_ncols) { + dense_smem_strategy strategy( + config_); + strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, + write_func, chunk_size); + } else { + hash_strategy strategy(config_); + strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, + write_func, chunk_size); + } }; + } // namespace distance } // namespace sparse }; // namespace raft diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh b/cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh new file mode 100644 index 0000000000..5ace978a23 --- /dev/null +++ b/cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../common.h" +#include "../detail/coo_spmv_kernel.cuh" +#include "../utils.cuh" +#include "coo_mask_row_iterators.cuh" + +#include +#include +#include + +namespace raft { +namespace sparse { +namespace distance { + +template +class coo_spmv_strategy { + public: + coo_spmv_strategy(const distances_config_t &config_) + : config(config_) { + smem = raft::getSharedMemPerBlock(); + } + + template + void _dispatch_base(strategy_t &strategy, int smem_dim, indptr_it &a_indptr, + value_t *out_dists, value_idx *coo_rows_b, + product_f product_func, accum_f accum_func, + write_f write_func, int chunk_size, int n_blocks, + int n_blocks_per_row) { + CUDA_CHECK(cudaFuncSetCacheConfig( + balanced_coo_generalized_spmv_kernel, + cudaFuncCachePreferShared)); + + balanced_coo_generalized_spmv_kernel + <<>>( + strategy, a_indptr, config.a_indices, config.a_data, config.a_nnz, + coo_rows_b, config.b_indices, config.b_data, config.a_nrows, + config.b_nrows, smem_dim, config.b_nnz, out_dists, n_blocks_per_row, + chunk_size, config.b_ncols, product_func, accum_func, write_func); + } + + template + void _dispatch_base_rev(strategy_t &strategy, int smem_dim, + indptr_it &b_indptr, value_t *out_dists, + value_idx *coo_rows_a, product_f product_func, + accum_f accum_func, write_f write_func, + int chunk_size, int n_blocks, int n_blocks_per_row) { + CUDA_CHECK(cudaFuncSetCacheConfig( + balanced_coo_generalized_spmv_kernel, + cudaFuncCachePreferShared)); + + balanced_coo_generalized_spmv_kernel + <<>>( + strategy, b_indptr, config.b_indices, config.b_data, config.b_nnz, + coo_rows_a, config.a_indices, config.a_data, config.b_nrows, + config.a_nrows, smem_dim, config.a_nnz, out_dists, n_blocks_per_row, + chunk_size, config.a_ncols, product_func, accum_func, write_func); + } + + protected: + int smem; + const distances_config_t &config; +}; + +} // namespace distance +} // namespace sparse +} // namespace raft diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh new file mode 100644 index 0000000000..44c3833f96 --- /dev/null +++ b/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../common.h" +#include "../utils.cuh" + +#include +#include + +namespace raft { +namespace sparse { +namespace distance { + +template +class mask_row_it { + public: + mask_row_it(const value_idx *full_indptr_, const value_idx &n_rows_, + value_idx *mask_row_idx_ = NULL) + : full_indptr(full_indptr_), mask_row_idx(mask_row_idx_), n_rows(n_rows_) {} + + __device__ inline value_idx get_row_idx(const int &n_blocks_nnz_b) { + if (mask_row_idx != NULL) { + return mask_row_idx[blockIdx.x / n_blocks_nnz_b]; + } else { + return blockIdx.x / n_blocks_nnz_b; + } + } + + __device__ inline void get_row_offsets( + const value_idx &row_idx, value_idx &start_offset, value_idx &stop_offset, + const value_idx &n_blocks_nnz_b, bool &first_a_chunk, bool &last_a_chunk) { + start_offset = full_indptr[row_idx]; + stop_offset = full_indptr[row_idx + 1] - 1; + } + + __device__ constexpr inline void get_indices_boundary( + const value_idx *indices, value_idx &indices_len, value_idx &start_offset, + value_idx &stop_offset, value_idx &start_index, value_idx &stop_index, + bool &first_a_chunk, bool &last_a_chunk) { + // do nothing; + } + + __device__ constexpr inline bool check_indices_bounds( + value_idx &start_index_a, value_idx &stop_index_a, value_idx &index_b) { + return true; + } + + const value_idx *full_indptr, &n_rows; + value_idx *mask_row_idx; +}; + +template +__global__ void fill_chunk_indices_kernel(value_idx *n_chunks_per_row, + value_idx *chunk_indices, + value_idx n_rows) { + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < n_rows) { + auto start = n_chunks_per_row[tid]; + auto end = n_chunks_per_row[tid + 1]; + +#pragma unroll + for (int i = start; i < end; i++) { + chunk_indices[i] = tid; + } + } +} + +template +class chunked_mask_row_it : public mask_row_it { + public: + chunked_mask_row_it(const value_idx *full_indptr_, const value_idx &n_rows_, + value_idx *mask_row_idx_, int row_chunk_size_, + const value_idx *n_chunks_per_row_, + const value_idx *chunk_indices_, + const cudaStream_t stream_) + : mask_row_it(full_indptr_, n_rows_, mask_row_idx_), + row_chunk_size(row_chunk_size_), + n_chunks_per_row(n_chunks_per_row_), + chunk_indices(chunk_indices_), + stream(stream_) {} + + static void init(const value_idx *indptr, const value_idx *mask_row_idx, + const value_idx &n_rows, const int row_chunk_size, + rmm::device_uvector &n_chunks_per_row, + rmm::device_uvector &chunk_indices, + cudaStream_t stream) { + auto policy = rmm::exec_policy(stream); + + constexpr value_idx first_element = 0; + n_chunks_per_row.set_element_async(0, first_element, stream); + n_chunks_per_row_functor chunk_functor(indptr, row_chunk_size); + thrust::transform(policy, mask_row_idx, mask_row_idx + n_rows, + n_chunks_per_row.begin() + 1, chunk_functor); + + thrust::inclusive_scan(policy, n_chunks_per_row.begin() + 1, + n_chunks_per_row.end(), + n_chunks_per_row.begin() + 1); + + raft::update_host(&total_row_blocks, n_chunks_per_row.data() + n_rows, 1, + stream); + + fill_chunk_indices(n_rows, n_chunks_per_row, chunk_indices, stream); + } + + __device__ inline value_idx get_row_idx(const int &n_blocks_nnz_b) { + return this->mask_row_idx[chunk_indices[blockIdx.x / n_blocks_nnz_b]]; + } + + __device__ inline void get_row_offsets( + const value_idx &row_idx, value_idx &start_offset, value_idx &stop_offset, + const int &n_blocks_nnz_b, bool &first_a_chunk, bool &last_a_chunk) { + auto chunk_index = blockIdx.x / n_blocks_nnz_b; + auto chunk_val = chunk_indices[chunk_index]; + auto prev_n_chunks = n_chunks_per_row[chunk_val]; + auto relative_chunk = chunk_index - prev_n_chunks; + first_a_chunk = relative_chunk == 0; + + start_offset = this->full_indptr[row_idx] + relative_chunk * row_chunk_size; + stop_offset = start_offset + row_chunk_size; + + auto final_stop_offset = this->full_indptr[row_idx + 1]; + + last_a_chunk = stop_offset >= final_stop_offset; + stop_offset = last_a_chunk ? final_stop_offset - 1 : stop_offset - 1; + } + + __device__ inline void get_indices_boundary( + const value_idx *indices, value_idx &row_idx, value_idx &start_offset, + value_idx &stop_offset, value_idx &start_index, value_idx &stop_index, + bool &first_a_chunk, bool &last_a_chunk) { + start_index = first_a_chunk ? start_index : indices[start_offset - 1] + 1; + stop_index = last_a_chunk ? stop_index : indices[stop_offset]; + } + + __device__ inline bool check_indices_bounds(value_idx &start_index_a, + value_idx &stop_index_a, + value_idx &index_b) { + return (index_b >= start_index_a && index_b <= stop_index_a); + } + + inline static value_idx total_row_blocks = 0; + const cudaStream_t stream; + const value_idx *n_chunks_per_row, *chunk_indices; + value_idx row_chunk_size; + + struct n_chunks_per_row_functor { + public: + n_chunks_per_row_functor(const value_idx *indptr_, + value_idx row_chunk_size_) + : indptr(indptr_), row_chunk_size(row_chunk_size_) {} + + __host__ __device__ value_idx operator()(const value_idx &i) { + auto degree = indptr[i + 1] - indptr[i]; + return raft::ceildiv(degree, (value_idx)row_chunk_size); + } + + const value_idx *indptr; + value_idx row_chunk_size; + }; + + private: + static void fill_chunk_indices( + const value_idx &n_rows, rmm::device_uvector &n_chunks_per_row, + rmm::device_uvector &chunk_indices, cudaStream_t stream) { + auto n_threads = std::min(n_rows, 256); + auto n_blocks = raft::ceildiv(n_rows, (value_idx)n_threads); + + chunk_indices.resize(total_row_blocks, stream); + + fill_chunk_indices_kernel<<>>( + n_chunks_per_row.data(), chunk_indices.data(), n_rows); + } +}; + +} // namespace distance +} // namespace sparse +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh new file mode 100644 index 0000000000..c463654a3b --- /dev/null +++ b/cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base_strategy.cuh" + +namespace raft { +namespace sparse { +namespace distance { + +template +class dense_smem_strategy : public coo_spmv_strategy { + public: + using smem_type = value_t *; + using insert_type = smem_type; + using find_type = smem_type; + + dense_smem_strategy(const distances_config_t &config_) + : coo_spmv_strategy(config_) {} + + inline static int smem_per_block(int n_cols) { + return (n_cols * sizeof(value_t)) + + ((1024 / raft::warp_size()) * sizeof(value_t)); + } + + template + void dispatch(value_t *out_dists, value_idx *coo_rows_b, + product_f product_func, accum_f accum_func, write_f write_func, + int chunk_size) { + auto n_blocks_per_row = + raft::ceildiv(this->config.b_nnz, chunk_size * 1024); + auto n_blocks = this->config.a_nrows * n_blocks_per_row; + + mask_row_it a_indptr(this->config.a_indptr, + this->config.a_nrows); + + this->_dispatch_base(*this, this->config.b_ncols, a_indptr, out_dists, + coo_rows_b, product_func, accum_func, write_func, + chunk_size, n_blocks, n_blocks_per_row); + } + + template + void dispatch_rev(value_t *out_dists, value_idx *coo_rows_a, + product_f product_func, accum_f accum_func, + write_f write_func, int chunk_size) { + auto n_blocks_per_row = + raft::ceildiv(this->config.a_nnz, chunk_size * 1024); + auto n_blocks = this->config.b_nrows * n_blocks_per_row; + + mask_row_it b_indptr(this->config.b_indptr, + this->config.b_nrows); + + this->_dispatch_base_rev(*this, this->config.a_ncols, b_indptr, out_dists, + coo_rows_a, product_func, accum_func, write_func, + chunk_size, n_blocks, n_blocks_per_row); + } + + __device__ inline insert_type init_insert(smem_type cache, + const value_idx &cache_size) { + for (int k = threadIdx.x; k < cache_size; k += blockDim.x) { + cache[k] = 0.0; + } + return cache; + } + + __device__ inline void insert(insert_type cache, const value_idx &key, + const value_t &value) { + cache[key] = value; + } + + __device__ inline find_type init_find(smem_type cache, + const value_idx &cache_size) { + return cache; + } + + __device__ inline value_t find(find_type cache, const value_idx &key) { + return cache[key]; + } +}; + +} // namespace distance +} // namespace sparse +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh b/cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh new file mode 100644 index 0000000000..1295d24103 --- /dev/null +++ b/cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base_strategy.cuh" + +#include + +// this is needed by cuco as key, value must be bitwise comparable. +// compilers don't declare float/double as bitwise comparable +// but that is too strict +// for example, the following is true (or 0): +// float a = 5; +// float b = 5; +// memcmp(&a, &b, sizeof(float)); +CUCO_DECLARE_BITWISE_COMPARABLE(float); +CUCO_DECLARE_BITWISE_COMPARABLE(double); + +namespace raft { +namespace sparse { +namespace distance { + +template +class hash_strategy : public coo_spmv_strategy { + public: + using insert_type = + typename cuco::static_map::device_mutable_view; + using smem_type = typename insert_type::slot_type *; + using find_type = + typename cuco::static_map::device_view; + + hash_strategy(const distances_config_t &config_, + float capacity_threshold_ = 0.5, int map_size_ = get_map_size()) + : coo_spmv_strategy(config_), + capacity_threshold(capacity_threshold_), + map_size(map_size_) {} + + void chunking_needed(const value_idx *indptr, const value_idx n_rows, + rmm::device_uvector &mask_indptr, + std::tuple &n_rows_divided, + cudaStream_t stream) { + auto policy = rmm::exec_policy(stream); + + auto less = thrust::copy_if( + policy, thrust::make_counting_iterator(value_idx(0)), + thrust::make_counting_iterator(n_rows), mask_indptr.data(), + fits_in_hash_table(indptr, 0, capacity_threshold * map_size)); + std::get<0>(n_rows_divided) = less - mask_indptr.data(); + + auto more = thrust::copy_if( + policy, thrust::make_counting_iterator(value_idx(0)), + thrust::make_counting_iterator(n_rows), less, + fits_in_hash_table(indptr, capacity_threshold * map_size, + std::numeric_limits::max())); + std::get<1>(n_rows_divided) = more - less; + } + + template + void dispatch(value_t *out_dists, value_idx *coo_rows_b, + product_f product_func, accum_f accum_func, write_f write_func, + int chunk_size) { + auto n_blocks_per_row = raft::ceildiv(this->config.b_nnz, chunk_size * tpb); + rmm::device_uvector mask_indptr( + this->config.a_nrows, this->config.handle.get_stream()); + std::tuple n_rows_divided; + + chunking_needed(this->config.a_indptr, this->config.a_nrows, mask_indptr, + n_rows_divided, this->config.handle.get_stream()); + + auto less_rows = std::get<0>(n_rows_divided); + if (less_rows > 0) { + mask_row_it less(this->config.a_indptr, less_rows, + mask_indptr.data()); + + auto n_less_blocks = less_rows * n_blocks_per_row; + this->_dispatch_base(*this, map_size, less, out_dists, coo_rows_b, + product_func, accum_func, write_func, chunk_size, + n_less_blocks, n_blocks_per_row); + } + + auto more_rows = std::get<1>(n_rows_divided); + if (more_rows > 0) { + rmm::device_uvector n_chunks_per_row( + more_rows + 1, this->config.handle.get_stream()); + rmm::device_uvector chunk_indices( + 0, this->config.handle.get_stream()); + chunked_mask_row_it::init( + this->config.a_indptr, mask_indptr.data() + less_rows, more_rows, + capacity_threshold * map_size, n_chunks_per_row, chunk_indices, + this->config.handle.get_stream()); + + chunked_mask_row_it more( + this->config.a_indptr, more_rows, mask_indptr.data() + less_rows, + capacity_threshold * map_size, n_chunks_per_row.data(), + chunk_indices.data(), this->config.handle.get_stream()); + + auto n_more_blocks = more.total_row_blocks * n_blocks_per_row; + this->_dispatch_base(*this, map_size, more, out_dists, coo_rows_b, + product_func, accum_func, write_func, chunk_size, + n_more_blocks, n_blocks_per_row); + } + } + + template + void dispatch_rev(value_t *out_dists, value_idx *coo_rows_a, + product_f product_func, accum_f accum_func, + write_f write_func, int chunk_size) { + auto n_blocks_per_row = raft::ceildiv(this->config.a_nnz, chunk_size * tpb); + rmm::device_uvector mask_indptr( + this->config.b_nrows, this->config.handle.get_stream()); + std::tuple n_rows_divided; + + chunking_needed(this->config.b_indptr, this->config.b_nrows, mask_indptr, + n_rows_divided, this->config.handle.get_stream()); + + auto less_rows = std::get<0>(n_rows_divided); + if (less_rows > 0) { + mask_row_it less(this->config.b_indptr, less_rows, + mask_indptr.data()); + + auto n_less_blocks = less_rows * n_blocks_per_row; + this->_dispatch_base_rev(*this, map_size, less, out_dists, coo_rows_a, + product_func, accum_func, write_func, chunk_size, + n_less_blocks, n_blocks_per_row); + } + + auto more_rows = std::get<1>(n_rows_divided); + if (more_rows > 0) { + rmm::device_uvector n_chunks_per_row( + more_rows + 1, this->config.handle.get_stream()); + rmm::device_uvector chunk_indices( + 0, this->config.handle.get_stream()); + chunked_mask_row_it::init( + this->config.b_indptr, mask_indptr.data() + less_rows, more_rows, + capacity_threshold * map_size, n_chunks_per_row, chunk_indices, + this->config.handle.get_stream()); + + chunked_mask_row_it more( + this->config.b_indptr, more_rows, mask_indptr.data() + less_rows, + capacity_threshold * map_size, n_chunks_per_row.data(), + chunk_indices.data(), this->config.handle.get_stream()); + + auto n_more_blocks = more.total_row_blocks * n_blocks_per_row; + this->_dispatch_base_rev(*this, map_size, more, out_dists, coo_rows_a, + product_func, accum_func, write_func, chunk_size, + n_more_blocks, n_blocks_per_row); + } + } + + __device__ inline insert_type init_insert(smem_type cache, + const value_idx &cache_size) { + return insert_type::make_from_uninitialized_slots( + cooperative_groups::this_thread_block(), cache, cache_size, -1, 0); + } + + __device__ inline void insert(insert_type cache, const value_idx &key, + const value_t &value) { + auto success = cache.insert(cuco::pair(key, value)); + } + + __device__ inline find_type init_find(smem_type cache, + const value_idx &cache_size) { + return find_type(cache, cache_size, -1, 0); + } + + __device__ inline value_t find(find_type cache, const value_idx &key) { + auto a_pair = cache.find(key); + + value_t a_col = 0.0; + if (a_pair != cache.end()) { + a_col = a_pair->second; + } + return a_col; + } + + struct fits_in_hash_table { + public: + fits_in_hash_table(const value_idx *indptr_, value_idx degree_l_, + value_idx degree_r_) + : indptr(indptr_), degree_l(degree_l_), degree_r(degree_r_) {} + + __host__ __device__ bool operator()(const value_idx &i) { + auto degree = indptr[i + 1] - indptr[i]; + + return degree >= degree_l && degree < degree_r; + } + + private: + const value_idx *indptr; + const value_idx degree_l, degree_r; + }; + + inline static int get_map_size() { + return (raft::getSharedMemPerBlock() - + ((tpb / raft::warp_size()) * sizeof(value_t))) / + sizeof(typename insert_type::slot_type); + } + + private: + float capacity_threshold; + int map_size; +}; + +} // namespace distance +} // namespace sparse +} // namespace raft diff --git a/cpp/include/raft/sparse/distance/csr_spmv.cuh b/cpp/include/raft/sparse/distance/csr_spmv.cuh deleted file mode 100644 index cd8ca09913..0000000000 --- a/cpp/include/raft/sparse/distance/csr_spmv.cuh +++ /dev/null @@ -1,486 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include - -#include - -#include - -namespace raft { -namespace sparse { -namespace distance { - -/** - * Semiring which schedules each row of B in a different thread. - * @tparam value_idx - * @tparam value_t - * @tparam tpb - * @tparam buffer_size - * @tparam rows_per_block - */ -template -struct BlockSemiring { - __device__ inline BlockSemiring(value_idx n_, value_idx *shared_cols_, - value_t *shared_vals_, value_idx *offsets_a_) - : n(n_), - a_cols(shared_cols_), - a_vals(shared_vals_), - offsets_a(offsets_a_), - done(false), - a_idx(0), - b_row_count(0), - cur_sum(0.0) {} - - /** - * Load columns for a single row of A into shared memory - * @param row - * @param indptrA - * @param indicesA - * @param dataA - */ - __device__ inline void load_a_shared(value_idx row, value_idx *indptrA, - value_idx *indicesA, value_t *dataA) { - if (threadIdx.x == 0) { - offsets_a[0] = indptrA[row]; - offsets_a[1] = indptrA[row + 1]; - } - __syncthreads(); - - value_idx start_offset_a = offsets_a[0]; - value_idx stop_offset_a = offsets_a[1]; - - a_size = stop_offset_a - start_offset_a; - - // Coalesce reads of row from matrix A into shared memory - for (int i = threadIdx.x; i < a_size; i += blockDim.x) { - a_cols[i] = indicesA[start_offset_a + i]; - a_vals[i] = dataA[start_offset_a + i]; - } - - __syncthreads(); - - row_a = row; - } - - /** - * Sets the head for A's pointers so they can be - * iterated in each thread. This is used for the - * case when the maximum degree of any row in A - * is too large to fit into shared memory, so we - * default to increasing the size of the L1 cache - * and suffering the uncoalesced memory accesses - * for both A and B. - * @param row - * @param indptrA - * @param indicesA - * @param dataA - */ - __device__ inline void load_a(value_idx row, value_idx *indptrA, - value_idx *indicesA, value_t *dataA) { - offsets_a[0] = indptrA[row]; - offsets_a[1] = indptrA[row + 1]; - - value_idx start_offset_a = offsets_a[0]; - value_idx stop_offset_a = offsets_a[1]; - - a_size = stop_offset_a - start_offset_a; - - a_cols = indicesA + start_offset_a; - a_vals = dataA + start_offset_a; - - row_a = row; - } - - /** - * Prepare index & offsets for looping through rows of B - * @param start_row - * @param indptrB - */ - __device__ inline void load_b(value_idx start_row, value_idx *indptrB) { - done = false; - a_idx = 0; - cur_sum = 0.0; - - value_idx start_row_b = start_row; - value_idx stop_row_b = min(start_row_b + tpb, n); - - n_rows_b = stop_row_b - start_row_b; - - if (threadIdx.x < n_rows_b) { - row_b = start_row_b + threadIdx.x; - value_idx start_offset_b = indptrB[row_b]; - b_row_count = indptrB[row_b + 1] - start_offset_b; - b_idx = start_offset_b; - b_idx_stop = start_offset_b + b_row_count; - } - } - - /** - * Perform single single column intersection/union for A & B - * based on the row of A mapped to shared memory and the row - * of B mapped to current thread. - * @param product_func - * @param accum_func - */ - __device__ inline void step(value_idx *b_cols, value_t *b_vals, - product_f product_func, accum_f accum_func) { - if (threadIdx.x < n_rows_b) { - bool local_idx_in_bounds = b_idx < b_idx_stop && b_row_count > 0; - - value_idx b = local_idx_in_bounds ? b_cols[b_idx] : -1; - value_t bv = local_idx_in_bounds ? b_vals[b_idx] : 0.0; - - bool a_idx_in_bounds = a_idx < a_size; - - value_idx a = a_idx_in_bounds ? a_cols[a_idx] : -1; - value_t av = a_idx_in_bounds ? a_vals[a_idx] : 0.0; - - bool run_b = ((b <= a && b != -1) || (b != -1 && a == -1)); - b_idx += 1 * run_b; - value_t b_side = bv * run_b; - - bool run_a = ((a <= b && a != -1) || (b == -1 && a != -1)); - a_idx += 1 * run_a; - value_t a_side = av * run_a; - - // Apply semiring "sum" & "product" functions locally - cur_sum = accum_func(cur_sum, product_func(b_side, a_side)); - - // finished when all items in chunk have been - // processed - done = b == -1 && a == -1; - - } else { - done = true; - } - } - - __device__ inline bool isdone() { return done; } - - __device__ inline void write(value_t *out) { - if (threadIdx.x < n_rows_b) { - out[(size_t)row_a * n + row_b] = cur_sum; - } - } - - private: - bool done; - - int a_size; - - value_idx n_rows_b; - - value_idx b_idx; - value_idx b_idx_stop; - value_idx a_idx; - - value_t cur_sum; - - value_idx n; - - value_idx row_a; - value_idx row_b; - - value_idx *offsets_a; - - // shared memory - value_idx b_row_count; - value_idx *a_cols; - value_t *a_vals; -}; - -/** - * Optimized for large numbers of rows but small enough numbers of columns - * that each thread can process their rows in parallel. - * @tparam value_idx index type - * @tparam value_t value type - * @tparam tpb block size - * @tparam product_f semiring product() function - * @tparam accum_f semiring sum() function - * @param[in] indptrA csr column index pointer array for A - * @param[in] indicesA csr column indices array for A - * @param[in] dataA csr data array for A - * @param[in] indptrB csr column index pointer array for B - * @param[in] indicesB csr column indices array for B - * @param[in] dataB csr data array for B - * @param[in] m number of rows in A - * @param[in] n number of rows in B - * @param[out] out dense output array of size m * n in row-major layout - * @param[in] n_blocks_per_row number of blocks of B scheduled per row of A - * @param[in] n_rows_per_block number of rows of A scheduled per block of B - * @param[in] buffer_size number of nonzeros to store in smem - * @param[in] product_func semiring product() function - * @param[in] accum_func semiring sum() function - */ -template -__global__ void classic_csr_semiring_spmv_smem_kernel( - value_idx *indptrA, value_idx *indicesA, value_t *dataA, value_idx *indptrB, - value_idx *indicesB, value_t *dataB, value_idx m, value_idx n, value_t *out, - int n_blocks_per_row, int n_rows_per_block, int buffer_size, - product_f product_func, accum_f accum_func) { - value_idx out_row = blockIdx.x / n_blocks_per_row; - value_idx out_col_start = blockIdx.x % n_blocks_per_row; - - value_idx row_b_start = out_col_start * n_rows_per_block; - - extern __shared__ char smem[]; - - value_idx *offsets_a = (value_idx *)smem; - value_idx *a_cols = offsets_a + 2; - buffer_size += - (sizeof(value_idx) != sizeof(value_t)) && ((buffer_size % 2) != 0); - value_t *a_vals = (value_t *)(a_cols + buffer_size); - - BlockSemiring semiring( - n, a_cols, a_vals, offsets_a); - - semiring.load_a_shared(out_row, indptrA, indicesA, dataA); - - if (out_row > m || row_b_start > n) return; - - // for each batch, parallelize the resulting rows across threads - for (int i = 0; i < n_rows_per_block; i += blockDim.x) { - semiring.load_b(row_b_start + i, indptrB); - do { - semiring.step(indicesB, dataB, product_func, accum_func); - } while (!semiring.isdone()); - - semiring.write(out); - } -} - -template -__global__ void classic_csr_semiring_spmv_kernel( - value_idx *indptrA, value_idx *indicesA, value_t *dataA, value_idx *indptrB, - value_idx *indicesB, value_t *dataB, value_idx m, value_idx n, value_t *out, - int n_blocks_per_row, int n_rows_per_block, product_f product_func, - accum_f accum_func) { - value_idx out_row = blockIdx.x / n_blocks_per_row; - value_idx out_col_start = blockIdx.x % n_blocks_per_row; - - value_idx row_b_start = out_col_start * n_rows_per_block; - - value_idx offsets_a[2]; - - BlockSemiring semiring( - n, indicesA, dataA, offsets_a); - - semiring.load_a(out_row, indptrA, indicesA, dataA); - - if (out_row > m || row_b_start > n) return; - - // for each batch, parallel the resulting rows across threads - for (int i = 0; i < n_rows_per_block; i += blockDim.x) { - semiring.load_b(row_b_start + i, indptrB); - do { - semiring.step(indicesB, dataB, product_func, accum_func); - } while (!semiring.isdone()); - - semiring.write(out); - } -} - -/** - * Compute the maximum number of nonzeros that can be stored in shared - * memory per block with the given index and value precision - * @return max nnz that can be stored in smem per block - */ -template -inline value_idx max_nnz_per_block() { - // max nnz = total smem - offsets for A - // (division because we need to store cols & vals separately) - return (raft::getSharedMemPerBlock() - (2 * sizeof(value_idx))) / - (sizeof(value_t) + sizeof(value_idx)); -} - -/** - * @tparam value_idx - * @param out - * @param in - * @param n - */ -template -__global__ void max_kernel(value_idx *out, value_idx *in, value_idx n) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - value_idx v = tid < n ? in[tid] - in[tid - 1] : 0; - value_idx agg = BlockReduce(temp_storage).Reduce(v, cub::Max()); - - if (threadIdx.x == 0) atomicMax(out, agg); -} - -template -inline value_idx max_degree( - value_idx *indptr, value_idx n_rows, - std::shared_ptr allocator, cudaStream_t stream) { - raft::mr::device::buffer max_d(allocator, stream, 1); - CUDA_CHECK(cudaMemsetAsync(max_d.data(), 0, sizeof(value_idx), stream)); - - /** - * A custom max reduction is performed until https://github.com/rapidsai/cuml/issues/3431 - * is fixed. - */ - max_kernel<<>>( - max_d.data(), indptr + 1, n_rows); - - value_idx max_h; - raft::update_host(&max_h, max_d.data(), 1, stream); - - CUDA_CHECK(cudaStreamSynchronize(stream)); - - return max_h; -} - -template -void _generalized_csr_pairwise_semiring( - value_t *out_dists, const distances_config_t &config_, - product_f product_func, accum_f accum_func) { - int n_chunks = 1; - int n_rows_per_block = min(n_chunks * threads_per_block, config_.b_nrows); - int n_blocks_per_row = raft::ceildiv(config_.b_nrows, n_rows_per_block); - int n_blocks = config_.a_nrows * n_blocks_per_row; - - CUDA_CHECK(cudaFuncSetCacheConfig( - classic_csr_semiring_spmv_kernel, - cudaFuncCachePreferL1)); - - classic_csr_semiring_spmv_kernel - <<>>( - config_.a_indptr, config_.a_indices, config_.a_data, config_.b_indptr, - config_.b_indices, config_.b_data, config_.a_nrows, config_.b_nrows, - out_dists, n_blocks_per_row, n_rows_per_block, product_func, accum_func); -}; - -template -void _generalized_csr_pairwise_smem_semiring( - value_t *out_dists, const distances_config_t &config_, - product_f product_func, accum_f accum_func, value_idx max_nnz) { - int n_chunks = 10000; - int n_rows_per_block = min(n_chunks * threads_per_block, config_.b_nrows); - int n_blocks_per_row = raft::ceildiv(config_.b_nrows, n_rows_per_block); - int n_blocks = config_.a_nrows * n_blocks_per_row; - - // TODO: Figure out why performance is worse with smaller smem sizes - int smem_size = raft::getSharedMemPerBlock(); - - CUDA_CHECK(cudaFuncSetCacheConfig( - classic_csr_semiring_spmv_smem_kernel, - cudaFuncCachePreferShared)); - - classic_csr_semiring_spmv_smem_kernel - <<>>( - config_.a_indptr, config_.a_indices, config_.a_data, config_.b_indptr, - config_.b_indices, config_.b_data, config_.a_nrows, config_.b_nrows, - out_dists, n_blocks_per_row, n_rows_per_block, max_nnz, product_func, - accum_func); -} - -/** - * Perform generalized sparse-matrix-sparse-vector multiply in - * a semiring algebra by allowing the product and sum operations - * to be defined. This approach saves the most memory as it can - * work directly on a CSR w/o the need for conversion to another - * sparse format, does not require any transposition, nor loading - * any vectors in dense form. The major drawback to this kernel - * is that the non-uniform memory access pattern dominates performance. - * When the shared memory option is used, bank conflicts also dominate - * performance, making it slower than other options but guaranteeing - * that the product() operation will be executed across every column - * in A and B. - * - * This is primarily useful when in cases where the product() operation - * is non-anniliating (e.g. product(x, 0) = x. - * - * There are two potential code paths for this primitive- if the largest - * degree of any row is small enough to fit in shared memory then shared - * memory is used to coalesce the reads from the vectors of A, otherwise - * no shared memory is used and all loads from A and B happen independently - * in separate threads. - * - * Iterators are maintained for the vectors from both A and B and each - * thread iterates to a maximum of |a|+|b| (which will happen only when - * the set of columns for vectors a and b are completely disjoint. - * - * TODO: Some potential things to try for future optimizations: - * - Always iterating for n_cols so that each warp is iterating - * a uniform number of times. - * - Computing an argsort() of B based on the number of columns - * in each row to attempt to load balance the warps naturally - * - Finding a way to coalesce the reads - * - * Ref: https://github.com/rapidsai/cuml/issues/3371 - * - * @tparam value_idx index type - * @tparam value_t value type - * @tparam product_f semiring product() function - * @tparam accum_f semiring sum() function - * @param[out] out_dists dense array of output distances size m * n in row-major layout - * @param[in] config_ distance config object - * @param[in] product_func semiring product() function - * @param[in] accum_func semiring sum() function - */ -template -void generalized_csr_pairwise_semiring( - value_t *out_dists, const distances_config_t &config_, - product_f product_func, accum_f accum_func) { - int nnz_upper_bound = max_nnz_per_block(); - - // max_nnz set from max(diff(indptrA)) - value_idx max_nnz = max_degree(config_.a_indptr, config_.a_nrows, - config_.allocator, config_.stream) + - 1; - - if (max_nnz <= nnz_upper_bound) - // use smem - _generalized_csr_pairwise_smem_semiring( - out_dists, config_, product_func, accum_func, max_nnz); - - else - // load each row of A separately - _generalized_csr_pairwise_semiring( - out_dists, config_, product_func, accum_func); -}; - -} // namespace distance -} // namespace sparse -}; // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh new file mode 100644 index 0000000000..51f9a05394 --- /dev/null +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace raft { +namespace sparse { +namespace distance { + +/** + * Load-balanced sparse-matrix-sparse-matrix multiplication (SPMM) kernel with + * sparse-matrix-sparse-vector multiplication layout (SPMV). + * This is intended to be scheduled n_chunks_b times for each row of a. + * The steps are as follows: + * + * 1. Load row from A into dense vector in shared memory. + * This can be further chunked in the future if necessary to support larger + * column sizes. + * 2. Threads of block all step through chunks of B in parallel. + * When a new row is encountered in row_indices_b, a segmented + * reduction is performed across the warps and then across the + * block and the final value written out to host memory. + * + * Reference: https://www.icl.utk.edu/files/publications/2020/icl-utk-1421-2020.pdf + * + * @tparam value_idx index type + * @tparam value_t value type + * @tparam tpb threads per block configured on launch + * @tparam rev if this is true, the reduce/accumulate functions are only + * executed when A[col] == 0.0. when executed before/after !rev + * and A & B are reversed, this allows the full symmetric difference + * and intersection to be computed. + * @tparam kv_t data type stored in shared mem cache + * @tparam product_f reduce function type (semiring product() function). + * accepts two arguments of value_t and returns a value_t + * @tparam accum_f accumulation function type (semiring sum() function). + * accepts two arguments of value_t and returns a value_t + * @tparam write_f function to write value out. this should be mathematically + * equivalent to the accumulate function but implemented as + * an atomic operation on global memory. Accepts two arguments + * of value_t* and value_t and updates the value given by the + * pointer. + * @param[in] indptrA column pointer array for A + * @param[in] indicesA column indices array for A + * @param[in] dataA data array for A + * @param[in] rowsB coo row array for B + * @param[in] indicesB column indices array for B + * @param[in] dataB data array for B + * @param[in] m number of rows in A + * @param[in] n number of rows in B + * @param[in] dim number of features + * @param[in] nnz_b number of nonzeros in B + * @param[out] out array of size m*n + * @param[in] n_blocks_per_row number of blocks of B per row of A + * @param[in] chunk_size number of nnz for B to use for each row of A + * @param[in] buffer_size amount of smem to use for each row of A + * @param[in] product_func semiring product() function + * @param[in] accum_func semiring sum() function + * @param[in] write_func atomic semiring sum() function + */ +template +__global__ void balanced_coo_generalized_spmv_kernel( + strategy_t strategy, indptr_it indptrA, value_idx *indicesA, value_t *dataA, + value_idx nnz_a, value_idx *rowsB, value_idx *indicesB, value_t *dataB, + value_idx m, value_idx n, int dim, value_idx nnz_b, value_t *out, + int n_blocks_per_row, int chunk_size, value_idx b_ncols, + product_f product_func, accum_f accum_func, write_f write_func) { + typedef cub::WarpReduce warp_reduce; + + value_idx cur_row_a = indptrA.get_row_idx(n_blocks_per_row); + value_idx cur_chunk_offset = blockIdx.x % n_blocks_per_row; + + // chunk starting offset + value_idx ind_offset = cur_chunk_offset * chunk_size * tpb; + // how many total cols will be processed by this block (should be <= chunk_size * n_threads) + value_idx active_chunk_size = min(chunk_size * tpb, nnz_b - ind_offset); + + int tid = threadIdx.x; + int warp_id = tid / raft::warp_size(); + + // compute id relative to current warp + unsigned int lane_id = tid & (raft::warp_size() - 1); + value_idx ind = ind_offset + threadIdx.x; + + extern __shared__ char smem[]; + + typename strategy_t::smem_type A = (typename strategy_t::smem_type)(smem); + typename warp_reduce::TempStorage *temp_storage = + (typename warp_reduce::TempStorage *)(A + dim); + + auto inserter = strategy.init_insert(A, dim); + + __syncthreads(); + + value_idx start_offset_a, stop_offset_a; + bool first_a_chunk, last_a_chunk; + indptrA.get_row_offsets(cur_row_a, start_offset_a, stop_offset_a, + n_blocks_per_row, first_a_chunk, last_a_chunk); + + // Convert current row vector in A to dense + for (int i = tid; i <= (stop_offset_a - start_offset_a); i += blockDim.x) { + strategy.insert(inserter, indicesA[start_offset_a + i], + dataA[start_offset_a + i]); + } + + __syncthreads(); + + auto finder = strategy.init_find(A, dim); + + if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return; + if (ind >= nnz_b) return; + + value_idx start_index_a = 0, stop_index_a = b_ncols - 1; + indptrA.get_indices_boundary(indicesA, cur_row_a, start_offset_a, + stop_offset_a, start_index_a, stop_index_a, + first_a_chunk, last_a_chunk); + + value_idx cur_row_b = -1; + value_t c = 0.0; + + auto warp_red = warp_reduce(*(temp_storage + warp_id)); + + if (tid < active_chunk_size) { + cur_row_b = rowsB[ind]; + + auto index_b = indicesB[ind]; + auto in_bounds = + indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); + + if (in_bounds) { + value_t a_col = strategy.find(finder, index_b); + if (!rev || a_col == 0.0) { + c = product_func(a_col, dataB[ind]); + } + } + } + + // loop through chunks in parallel, reducing when a new row is + // encountered by each thread + for (int i = tid; i < active_chunk_size; i += blockDim.x) { + value_idx ind_next = ind + blockDim.x; + value_idx next_row_b = -1; + + if (i + blockDim.x < active_chunk_size) next_row_b = rowsB[ind_next]; + + bool diff_rows = next_row_b != cur_row_b; + + if (__any_sync(0xffffffff, diff_rows)) { + // grab the threads currently participating in loops. + // because any other threads should have returned already. + unsigned int peer_group = __match_any_sync(0xffffffff, cur_row_b); + bool is_leader = get_lowest_peer(peer_group) == lane_id; + value_t v = warp_red.HeadSegmentedReduce(c, is_leader, accum_func); + + // thread with lowest lane id among peers writes out + if (is_leader && v != 0.0) { + // this conditional should be uniform, since rev is constant + size_t idx = !rev ? (size_t)cur_row_a * n + cur_row_b + : (size_t)cur_row_b * m + cur_row_a; + write_func(out + idx, v); + } + + c = 0.0; + } + + if (next_row_b != -1) { + ind = ind_next; + + auto index_b = indicesB[ind]; + auto in_bounds = + indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); + if (in_bounds) { + value_t a_col = strategy.find(finder, index_b); + + if (!rev || a_col == 0.0) { + c = accum_func(c, product_func(a_col, dataB[ind])); + } + } + + cur_row_b = next_row_b; + } + } +} + +} // namespace distance +} // namespace sparse +} // namespace raft diff --git a/cpp/include/raft/sparse/distance/ip_distance.cuh b/cpp/include/raft/sparse/distance/ip_distance.cuh index 90717bfc5f..bf45fe0d8e 100644 --- a/cpp/include/raft/sparse/distance/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/ip_distance.cuh @@ -35,210 +35,24 @@ #include -#include - namespace raft { namespace sparse { namespace distance { -/** - * A simple interface that enables different instances - * of inner product. Currently, there are two implementations: - * cusparse gemm and our own semiring spmv. - * @tparam value_idx - * @tparam value_t - */ -template -class ip_trans_getters_t : public distances_t { - public: - /** - * A copy of B's data in coo format. This is - * useful for downstream distances that - * might be able to compute a norm instead of - * point-wise products. - * @return - */ - virtual value_t *b_data_coo() = 0; - - /** - * A copy of B's rows in coo format. This is - * useful for downstream distances that - * might be able to compute a norm instead of - * point-wise products. - * @return - */ - virtual value_idx *b_rows_coo() = 0; - - virtual ~ip_trans_getters_t() = default; -}; - -/** - * Simple inner product distance with sparse matrix multiply. This - * uses cusparse and requires both B to be transposed as well as - * the output to be explicitly converted to dense form (which requires - * 3 copies of the dense data- 2 for the cusparse csr output and - * 1 for the final m*n dense matrix.) - */ template -class ip_distances_gemm_t : public ip_trans_getters_t { - public: - /** - * Computes simple sparse inner product distances as sum(x_y * y_k) - * @param[in] config specifies inputs, outputs, and sizes - * - * TODO: Remove this once we have a semiring SPGEMM - * Ref: https://github.com/rapidsai/cuml/issues/3371 - */ - explicit ip_distances_gemm_t( - const distances_config_t &config) - : config_(&config), - workspace(config.allocator, config.stream, 0), - csc_indptr(config.allocator, config.stream, 0), - csc_indices(config.allocator, config.stream, 0), - csc_data(config.allocator, config.stream, 0), - alpha(1.0) { - init_mat_descriptor(matA); - init_mat_descriptor(matB); - init_mat_descriptor(matC); - init_mat_descriptor(matD); - - CUSPARSE_CHECK(cusparseCreateCsrgemm2Info(&info)); - - CUSPARSE_CHECK(cusparseGetPointerMode(config.handle, &orig_ptr_mode)); - - CUSPARSE_CHECK( - cusparseSetPointerMode(config.handle, CUSPARSE_POINTER_MODE_HOST)); - } - - /** - * Performs pairwise distance computation and computes output distances - * @param out_distances dense output matrix (size a_nrows * b_nrows) - */ - void compute(value_t *out_distances) { - /** - * Compute pairwise distances and return dense matrix in column-major format - */ - raft::mr::device::buffer out_batch_indptr( - config_->allocator, config_->stream, config_->a_nrows + 1); - raft::mr::device::buffer out_batch_indices(config_->allocator, - config_->stream, 0); - raft::mr::device::buffer out_batch_data(config_->allocator, - config_->stream, 0); - - value_idx out_batch_nnz = get_nnz(out_batch_indptr.data()); - - out_batch_indices.resize(out_batch_nnz, config_->stream); - out_batch_data.resize(out_batch_nnz, config_->stream); - - compute_gemm(out_batch_indptr.data(), out_batch_indices.data(), - out_batch_data.data()); - - raft::sparse::convert::csr_to_dense( - config_->handle, config_->a_nrows, config_->b_nrows, - out_batch_indptr.data(), out_batch_indices.data(), out_batch_data.data(), - config_->a_nrows, out_distances, config_->stream, true); - } - - virtual value_idx *b_rows_coo() { return csc_indices.data(); } - - value_t *b_data_coo() { return csc_data.data(); } - - ~ip_distances_gemm_t() { - CUSPARSE_CHECK_NO_THROW(cusparseDestroyMatDescr(matA)); - CUSPARSE_CHECK_NO_THROW(cusparseDestroyMatDescr(matB)); - CUSPARSE_CHECK_NO_THROW(cusparseDestroyMatDescr(matC)); - CUSPARSE_CHECK_NO_THROW(cusparseDestroyMatDescr(matD)); - - CUSPARSE_CHECK_NO_THROW( - cusparseSetPointerMode(config_->handle, orig_ptr_mode)); - } - - private: - void init_mat_descriptor(cusparseMatDescr_t &mat) { - CUSPARSE_CHECK(cusparseCreateMatDescr(&mat)); - CUSPARSE_CHECK(cusparseSetMatIndexBase(mat, CUSPARSE_INDEX_BASE_ZERO)); - CUSPARSE_CHECK(cusparseSetMatType(mat, CUSPARSE_MATRIX_TYPE_GENERAL)); - } - - value_idx get_nnz(value_idx *csr_out_indptr) { - value_idx m = config_->a_nrows, n = config_->b_nrows, k = config_->a_ncols; - - transpose_b(); - - size_t workspace_size; - - CUSPARSE_CHECK(raft::sparse::cusparsecsrgemm2_buffersizeext( - config_->handle, m, n, k, &alpha, NULL, matA, config_->a_nnz, - config_->a_indptr, config_->a_indices, matB, config_->b_nnz, - csc_indptr.data(), csc_indices.data(), matD, 0, NULL, NULL, info, - &workspace_size, config_->stream)); - - workspace.resize(workspace_size, config_->stream); - - value_idx out_nnz = 0; - - CUSPARSE_CHECK(raft::sparse::cusparsecsrgemm2nnz( - config_->handle, m, n, k, matA, config_->a_nnz, config_->a_indptr, - config_->a_indices, matB, config_->b_nnz, csc_indptr.data(), - csc_indices.data(), matD, 0, NULL, NULL, matC, csr_out_indptr, &out_nnz, - info, workspace.data(), config_->stream)); - - return out_nnz; - } - - void compute_gemm(const value_idx *csr_out_indptr, value_idx *csr_out_indices, - value_t *csr_out_data) { - value_idx m = config_->a_nrows, n = config_->b_nrows, k = config_->a_ncols; - - CUSPARSE_CHECK(raft::sparse::cusparsecsrgemm2( - config_->handle, m, n, k, &alpha, matA, config_->a_nnz, config_->a_data, - config_->a_indptr, config_->a_indices, matB, config_->b_nnz, - csc_data.data(), csc_indptr.data(), csc_indices.data(), NULL, matD, 0, - NULL, NULL, NULL, matC, csr_out_data, csr_out_indptr, csr_out_indices, - info, workspace.data(), config_->stream)); - } - - void transpose_b() { - /** - * Transpose index array into csc - */ - csc_indptr.resize(config_->b_ncols + 1, config_->stream); - csc_indices.resize(config_->b_nnz, config_->stream); - csc_data.resize(config_->b_nnz, config_->stream); - - raft::sparse::linalg::csr_transpose( - config_->handle, config_->b_indptr, config_->b_indices, config_->b_data, - csc_indptr.data(), csc_indices.data(), csc_data.data(), config_->b_nrows, - config_->b_ncols, config_->b_nnz, config_->allocator, config_->stream); - } - - value_t alpha; - csrgemm2Info_t info; - cusparseMatDescr_t matA; - cusparseMatDescr_t matB; - cusparseMatDescr_t matC; - cusparseMatDescr_t matD; - cusparsePointerMode_t orig_ptr_mode; - raft::mr::device::buffer workspace; - raft::mr::device::buffer csc_indptr; - raft::mr::device::buffer csc_indices; - raft::mr::device::buffer csc_data; - const distances_config_t *config_; -}; - -template -class ip_distances_spmv_t : public ip_trans_getters_t { +class ip_distances_t : public distances_t { public: /** * Computes simple sparse inner product distances as sum(x_y * y_k) * @param[in] config specifies inputs, outputs, and sizes */ - ip_distances_spmv_t(const distances_config_t &config) + ip_distances_t(const distances_config_t &config) : config_(&config), - coo_rows_b(config.allocator, config.stream, config.b_nnz) { + coo_rows_b(config.handle.get_device_allocator(), + config.handle.get_stream(), config.b_nnz) { raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows, coo_rows_b.data(), config_->b_nnz, - config_->stream); + config_->handle.get_stream()); } /** @@ -258,66 +72,10 @@ class ip_distances_spmv_t : public ip_trans_getters_t { value_t *b_data_coo() { return config_->b_data; } - ~ip_distances_spmv_t() = default; - private: const distances_config_t *config_; raft::mr::device::buffer coo_rows_b; }; - -template -class ip_distances_t : public distances_t { - public: - /** - * Computes simple sparse inner product distances as sum(x_y * y_k) - * @param[in] config specifies inputs, outputs, and sizes - */ - explicit ip_distances_t(const distances_config_t &config) - : config_(&config) { - if (config_->a_ncols < max_cols_per_block()) { - internal_ip_dist = - std::make_unique>(*config_); - } else { - internal_ip_dist = - std::make_unique>(*config_); - } - } - - /** - * Performs pairwise distance computation and computes output distances - * @param out_distances dense output matrix (size a_nrows * b_nrows) - */ - void compute(value_t *out_distances) { - /** - * Compute pairwise distances and return dense matrix in column-major format - */ - internal_ip_dist->compute(out_distances); - } - - virtual value_idx *b_rows_coo() const { - return internal_ip_dist->b_rows_coo(); - } - - virtual value_t *b_data_coo() const { return internal_ip_dist->b_data_coo(); } - - private: - const distances_config_t *config_; - std::unique_ptr> internal_ip_dist; -}; - -/** - * Compute pairwise distances between A and B, using the provided - * input configuration and distance function. - * - * @tparam value_idx index type - * @tparam value_t value type - * @param[out] out dense output array (size A.nrows * B.nrows) - * @param[in] input_config input argument configuration - * @param[in] metric distance metric to use - */ -template class ip_distances_t; -template class distances_config_t; - }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 829471e0e3..f73e23d94b 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -50,6 +50,17 @@ __global__ void compute_row_norm_kernel(value_t *out, } } +template +__global__ void compute_row_sum_kernel(value_t *out, + const value_idx *__restrict__ coo_rows, + const value_t *__restrict__ data, + value_idx nnz) { + value_idx i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < nnz) { + atomicAdd(&out[coo_rows[i]], data[i]); + } +} + template __global__ void compute_euclidean_warp_kernel( value_t *__restrict__ C, const value_t *__restrict__ Q_sq_norms, @@ -67,12 +78,10 @@ __global__ void compute_euclidean_warp_kernel( value_t val = expansion_func(dot, Q_sq_norms[i], R_sq_norms[j]); // correct for small instabilities - if (fabs(val) < 0.0001) val = 0.0; - - C[(size_t)i * n_cols + j] = val; + C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001); } -template void compute_euclidean(value_t *C, const value_t *Q_sq_norms, const value_t *R_sq_norms, value_idx n_rows, @@ -83,13 +92,12 @@ void compute_euclidean(value_t *C, const value_t *Q_sq_norms, C, Q_sq_norms, R_sq_norms, n_rows, n_cols, expansion_func); } -template void compute_l2(value_t *out, const value_idx *Q_coo_rows, const value_t *Q_data, value_idx Q_nnz, const value_idx *R_coo_rows, const value_t *R_data, value_idx R_nnz, value_idx m, value_idx n, - cusparseHandle_t handle, std::shared_ptr alloc, cudaStream_t stream, expansion_f expansion_func) { raft::mr::device::buffer Q_sq_norms(alloc, stream, m); @@ -117,9 +125,7 @@ class l2_expanded_distances_t : public distances_t { public: explicit l2_expanded_distances_t( const distances_config_t &config) - : config_(&config), - workspace(config.allocator, config.stream, 0), - ip_dists(config) {} + : config_(&config), ip_dists(config) {} void compute(value_t *out_dists) { ip_dists.compute(out_dists); @@ -128,15 +134,16 @@ class l2_expanded_distances_t : public distances_t { value_t *b_data = ip_dists.b_data_coo(); raft::mr::device::buffer search_coo_rows( - config_->allocator, config_->stream, config_->a_nnz); + config_->handle.get_device_allocator(), config_->handle.get_stream(), + config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, - config_->stream); + config_->handle.get_stream()); compute_l2( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, - config_->handle, config_->allocator, config_->stream, + config_->handle.get_device_allocator(), config_->handle.get_stream(), [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { return -2 * dot + q_norm + r_norm; }); @@ -146,7 +153,6 @@ class l2_expanded_distances_t : public distances_t { protected: const distances_config_t *config_; - raft::mr::device::buffer workspace; ip_distances_t ip_dists; }; @@ -171,7 +177,7 @@ class l2_sqrt_expanded_distances_t int neg = input < 0 ? -1 : 1; return sqrt(abs(input) * neg); }, - this->config_->stream); + this->config_->handle.get_stream()); } ~l2_sqrt_expanded_distances_t() = default; @@ -187,7 +193,8 @@ class cosine_expanded_distances_t : public distances_t { explicit cosine_expanded_distances_t( const distances_config_t &config) : config_(&config), - workspace(config.allocator, config.stream, 0), + workspace(config.handle.get_device_allocator(), + config.handle.get_stream(), 0), ip_dists(config) {} void compute(value_t *out_dists) { @@ -197,15 +204,16 @@ class cosine_expanded_distances_t : public distances_t { value_t *b_data = ip_dists.b_data_coo(); raft::mr::device::buffer search_coo_rows( - config_->allocator, config_->stream, config_->a_nnz); + config_->handle.get_device_allocator(), config_->handle.get_stream(), + config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, - config_->stream); + config_->handle.get_stream()); compute_l2( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, - config_->handle, config_->allocator, config_->stream, + config_->handle.get_device_allocator(), config_->handle.get_stream(), [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { value_t norms = sqrt(q_norm) * sqrt(r_norm); // deal with potential for 0 in denominator by forcing 0/1 instead @@ -240,33 +248,22 @@ class hellinger_expanded_distances_t : public distances_t { explicit hellinger_expanded_distances_t( const distances_config_t &config) : config_(&config), - workspace(config.allocator, config.stream, 0), - ip_dists(config) {} + workspace(config.handle.get_device_allocator(), + config.handle.get_stream(), 0) {} void compute(value_t *out_dists) { - // First sqrt A and B - raft::linalg::unaryOp( - config_->a_data, config_->a_data, config_->a_nnz, - [=] __device__(value_t input) { return sqrt(input); }, config_->stream); + raft::mr::device::buffer coo_rows( + config_->handle.get_device_allocator(), config_->handle.get_stream(), + max(config_->b_nnz, config_->a_nnz)); - if (config_->a_data != config_->b_data) { - raft::linalg::unaryOp( - config_->b_data, config_->b_data, config_->b_nnz, - [=] __device__(value_t input) { return sqrt(input); }, config_->stream); - } + raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows, + coo_rows.data(), config_->b_nnz, + config_->handle.get_stream()); - ip_dists.compute(out_dists); - - // Revert sqrt of A and B - raft::linalg::unaryOp( - config_->a_data, config_->a_data, config_->a_nnz, - [=] __device__(value_t input) { return input * input; }, config_->stream); - if (config_->a_data != config_->b_data) { - raft::linalg::unaryOp( - config_->b_data, config_->b_data, config_->b_nnz, - [=] __device__(value_t input) { return input * input; }, - config_->stream); - } + balanced_coo_pairwise_generalized_spmv( + out_dists, *config_, coo_rows.data(), + [] __device__(value_t a, value_t b) { return sqrt(a) * sqrt(b); }, Sum(), + AtomicAdd()); raft::linalg::unaryOp( out_dists, out_dists, config_->a_nrows * config_->b_nrows, @@ -275,7 +272,7 @@ class hellinger_expanded_distances_t : public distances_t { bool rectifier = (1 - input) > 0; return sqrt(rectifier * (1 - input)); }, - config_->stream); + config_->handle.get_stream()); } ~hellinger_expanded_distances_t() = default; @@ -283,7 +280,6 @@ class hellinger_expanded_distances_t : public distances_t { private: const distances_config_t *config_; raft::mr::device::buffer workspace; - ip_distances_t ip_dists; }; }; // END namespace distance diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/lp_distance.cuh index e524d87b7c..653dc55683 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/lp_distance.cuh @@ -31,7 +31,6 @@ #include #include -#include #include #include @@ -42,51 +41,26 @@ namespace distance { template - void unexpanded_lp_distances( value_t *out_dists, const distances_config_t *config_, product_f product_func, accum_f accum_func, write_f write_func) { - /** - * @TODO: Main logic here: - * - * - if n_cols < available smem, just use dense conversion for rows of A - * - if n_cols > available smem but max nnz < available smem, use hashing - * (not yet available) - * - if n_cols > available smem & max_nnz > available smem, - * use batching + hashing only for those large cols - * Ref: https://github.com/rapidsai/cuml/issues/3371 - */ - - if (config_->a_ncols < max_cols_per_block()) { - // TODO: Use n_cols to set shared memory and threads per block - // for max occupancy. - // Ref: https://github.com/rapidsai/cuml/issues/3371 + raft::mr::device::buffer coo_rows( + config_->handle.get_device_allocator(), config_->handle.get_stream(), + max(config_->b_nnz, config_->a_nnz)); - raft::mr::device::buffer coo_rows( - config_->allocator, config_->stream, max(config_->b_nnz, config_->a_nnz)); + raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows, + coo_rows.data(), config_->b_nnz, + config_->handle.get_stream()); - raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows, - coo_rows.data(), config_->b_nnz, - config_->stream); + balanced_coo_pairwise_generalized_spmv( + out_dists, *config_, coo_rows.data(), product_func, accum_func, write_func); - balanced_coo_pairwise_generalized_spmv( - out_dists, *config_, coo_rows.data(), product_func, accum_func, - write_func); + raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, + coo_rows.data(), config_->a_nnz, + config_->handle.get_stream()); - raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, - coo_rows.data(), config_->a_nnz, - config_->stream); - - balanced_coo_pairwise_generalized_spmv_rev( - out_dists, *config_, coo_rows.data(), product_func, accum_func, - write_func); - - } else { - // TODO: Find max nnz and set smem based on this value. - // Ref: https://github.com/rapidsai/cuml/issues/3371 - generalized_csr_pairwise_semiring( - out_dists, *config_, product_func, accum_func); - } + balanced_coo_pairwise_generalized_spmv_rev( + out_dists, *config_, coo_rows.data(), product_func, accum_func, write_func); } /** @@ -145,7 +119,7 @@ class l2_sqrt_unexpanded_distances_t int neg = input < 0 ? -1 : 1; return sqrt(abs(input) * neg); }, - this->config_->stream); + this->config_->handle.get_stream()); } }; @@ -204,14 +178,13 @@ class lp_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp( out_dists, out_dists, config_->a_nrows * config_->b_nrows, [=] __device__(value_t input) { return pow(input, one_over_p); }, - config_->stream); + config_->handle.get_stream()); } private: const distances_config_t *config_; value_t p; }; - }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/operators.cuh b/cpp/include/raft/sparse/distance/operators.cuh index d14a42b407..89acda8b1a 100644 --- a/cpp/include/raft/sparse/distance/operators.cuh +++ b/cpp/include/raft/sparse/distance/operators.cuh @@ -29,6 +29,13 @@ struct Sum { } }; +struct NotEqual { + template + __host__ __device__ __forceinline__ value_t operator()(value_t a, value_t b) { + return a != b; + } +}; + struct SqDiff { template __host__ __device__ __forceinline__ value_t operator()(value_t a, value_t b) { diff --git a/cpp/include/raft/sparse/distance/utils.cuh b/cpp/include/raft/sparse/distance/utils.cuh new file mode 100644 index 0000000000..6b6d77a2d5 --- /dev/null +++ b/cpp/include/raft/sparse/distance/utils.cuh @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +#include + +namespace raft { +namespace sparse { +namespace distance { + +/** + * Computes the maximum number of columns that can be stored + * in shared memory in dense form with the given block size + * and precision. + * @return the maximum number of columns that can be stored in smem + */ +template +inline int max_cols_per_block() { + // max cols = (total smem available - cub reduction smem) + return (raft::getSharedMemPerBlock() - + ((tpb / raft::warp_size()) * sizeof(value_t))) / + sizeof(value_t); +} + +} // namespace distance +} // namespace sparse +} // namespace raft diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index e327386d13..47df50bbe5 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -120,9 +120,7 @@ class sparse_knn_t { const value_idx *queryIndices_, const value_t *queryData_, size_t queryNNZ_, int n_query_rows_, int n_query_cols_, value_idx *output_indices_, value_t *output_dists_, int k_, - cusparseHandle_t cusparseHandle_, - std::shared_ptr allocator_, - cudaStream_t stream_, + raft::handle_t &handle_, size_t batch_size_index_ = 2 << 14, // approx 1M size_t batch_size_query_ = 2 << 14, raft::distance::DistanceType metric_ = @@ -143,9 +141,7 @@ class sparse_knn_t { output_indices(output_indices_), output_dists(output_dists_), k(k_), - cusparseHandle(cusparseHandle_), - allocator(allocator_), - stream(stream_), + handle(handle_), batch_size_index(batch_size_index_), batch_size_query(batch_size_query_), metric(metric_), @@ -171,22 +167,25 @@ class sparse_knn_t { */ rmm::device_uvector query_batch_indptr( - query_batcher.batch_rows() + 1, stream); + query_batcher.batch_rows() + 1, handle.get_stream()); value_idx n_query_batch_nnz = query_batcher.get_batch_csr_indptr_nnz( - query_batch_indptr.data(), stream); + query_batch_indptr.data(), handle.get_stream()); rmm::device_uvector query_batch_indices(n_query_batch_nnz, - stream); - rmm::device_uvector query_batch_data(n_query_batch_nnz, stream); + handle.get_stream()); + rmm::device_uvector query_batch_data(n_query_batch_nnz, + handle.get_stream()); query_batcher.get_batch_csr_indices_data(query_batch_indices.data(), - query_batch_data.data(), stream); + query_batch_data.data(), + handle.get_stream()); // A 3-partition temporary merge space to scale the batching. 2 parts for subsequent // batches and 1 space for the results of the merge, which get copied back to the top - rmm::device_uvector merge_buffer_indices(0, stream); - rmm::device_uvector merge_buffer_dists(0, stream); + rmm::device_uvector merge_buffer_indices(0, + handle.get_stream()); + rmm::device_uvector merge_buffer_dists(0, handle.get_stream()); value_t *dists_merge_buffer_ptr; value_idx *indices_merge_buffer_ptr; @@ -198,32 +197,36 @@ class sparse_knn_t { for (int j = 0; j < n_batches_idx; j++) { idx_batcher.set_batch(j); - merge_buffer_indices.resize(query_batcher.batch_rows() * k * 3, stream); - merge_buffer_dists.resize(query_batcher.batch_rows() * k * 3, stream); + merge_buffer_indices.resize(query_batcher.batch_rows() * k * 3, + handle.get_stream()); + merge_buffer_dists.resize(query_batcher.batch_rows() * k * 3, + handle.get_stream()); /** * Slice CSR to rows in batch */ rmm::device_uvector idx_batch_indptr( - idx_batcher.batch_rows() + 1, stream); - rmm::device_uvector idx_batch_indices(0, stream); - rmm::device_uvector idx_batch_data(0, stream); + idx_batcher.batch_rows() + 1, handle.get_stream()); + rmm::device_uvector idx_batch_indices(0, + handle.get_stream()); + rmm::device_uvector idx_batch_data(0, handle.get_stream()); - value_idx idx_batch_nnz = - idx_batcher.get_batch_csr_indptr_nnz(idx_batch_indptr.data(), stream); + value_idx idx_batch_nnz = idx_batcher.get_batch_csr_indptr_nnz( + idx_batch_indptr.data(), handle.get_stream()); - idx_batch_indices.resize(idx_batch_nnz, stream); - idx_batch_data.resize(idx_batch_nnz, stream); + idx_batch_indices.resize(idx_batch_nnz, handle.get_stream()); + idx_batch_data.resize(idx_batch_nnz, handle.get_stream()); - idx_batcher.get_batch_csr_indices_data(idx_batch_indices.data(), - idx_batch_data.data(), stream); + idx_batcher.get_batch_csr_indices_data( + idx_batch_indices.data(), idx_batch_data.data(), handle.get_stream()); /** * Compute distances */ size_t dense_size = idx_batcher.batch_rows() * query_batcher.batch_rows(); - rmm::device_uvector batch_dists(dense_size, stream); + rmm::device_uvector batch_dists(dense_size, + handle.get_stream()); CUDA_CHECK(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t))); @@ -236,13 +239,14 @@ class sparse_knn_t { // Build batch indices array rmm::device_uvector batch_indices(batch_dists.size(), - stream); + handle.get_stream()); // populate batch indices array value_idx batch_rows = query_batcher.batch_rows(), batch_cols = idx_batcher.batch_rows(); - iota_fill(batch_indices.data(), batch_rows, batch_cols, stream); + iota_fill(batch_indices.data(), batch_rows, batch_cols, + handle.get_stream()); /** * Perform k-selection on batch & merge with other k-selections @@ -276,19 +280,19 @@ class sparse_knn_t { // copy merged output back into merge buffer partition for next iteration raft::copy_async(merge_buffer_indices.data(), indices_merge_buffer_tmp_ptr, - batch_rows * k, stream); + batch_rows * k, handle.get_stream()); raft::copy_async(merge_buffer_dists.data(), dists_merge_buffer_tmp_ptr, batch_rows * k, - stream); + handle.get_stream()); } // Copy final merged batch to output array - raft::copy_async(output_indices + (rows_processed * k), - merge_buffer_indices.data(), - query_batcher.batch_rows() * k, stream); - raft::copy_async(output_dists + (rows_processed * k), - merge_buffer_dists.data(), - query_batcher.batch_rows() * k, stream); + raft::copy_async( + output_indices + (rows_processed * k), merge_buffer_indices.data(), + query_batcher.batch_rows() * k, handle.get_stream()); + raft::copy_async( + output_dists + (rows_processed * k), merge_buffer_dists.data(), + query_batcher.batch_rows() * k, handle.get_stream()); rows_processed += query_batcher.batch_rows(); } @@ -305,14 +309,14 @@ class sparse_knn_t { id_ranges.push_back(0); id_ranges.push_back(idx_batcher.batch_start()); - rmm::device_uvector trans(id_ranges.size(), stream); + rmm::device_uvector trans(id_ranges.size(), handle.get_stream()); raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), - stream); + handle.get_stream()); // combine merge buffers only if there's more than 1 partition to combine raft::spatial::knn::knn_merge_parts( merge_buffer_dists, merge_buffer_indices, out_dists, out_indices, - query_batcher.batch_rows(), 2, k, stream, trans.data()); + query_batcher.batch_rows(), 2, k, handle.get_stream(), trans.data()); } void perform_k_selection(csr_batcher_t idx_batcher, @@ -337,7 +341,7 @@ class sparse_knn_t { // kernel to slice first (min) k cols and copy into batched merge buffer select_k(batch_dists, batch_indices, batch_rows, batch_cols, out_dists, - out_indices, ascending, n_neighbors, stream); + out_indices, ascending, n_neighbors, handle.get_stream()); } void compute_distances(csr_batcher_t &idx_batcher, @@ -351,7 +355,8 @@ class sparse_knn_t { /** * Compute distances */ - raft::sparse::distance::distances_config_t dist_config; + raft::sparse::distance::distances_config_t dist_config( + handle); dist_config.b_nrows = idx_batcher.batch_rows(); dist_config.b_ncols = n_idx_cols; dist_config.b_nnz = idx_batch_nnz; @@ -368,10 +373,6 @@ class sparse_knn_t { dist_config.a_indices = query_batch_indices; dist_config.a_data = query_batch_data; - dist_config.handle = cusparseHandle; - dist_config.allocator = allocator; - dist_config.stream = stream; - if (raft::sparse::distance::supportedDistance.find(metric) == raft::sparse::distance::supportedDistance.end()) THROW("DistanceType not supported: %d", metric); @@ -393,11 +394,7 @@ class sparse_knn_t { int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k; - cusparseHandle_t cusparseHandle; - - std::shared_ptr allocator; - - cudaStream_t stream; + raft::handle_t &handle; }; /** @@ -419,7 +416,7 @@ class sparse_knn_t { * @param[in] k the number of neighbors to query * @param[in] cusparseHandle the initialized cusparseHandle instance to use * @param[in] allocator device allocator instance to use - * @param[in] stream CUDA stream to order operations with respect to + * @param[in] handle.get_stream() CUDA handle.get_stream() to order operations with respect to * @param[in] batch_size_index maximum number of rows to use from index matrix per batch * @param[in] batch_size_query maximum number of rows to use from query matrix per batch * @param[in] metric distance metric/measure to use @@ -432,9 +429,7 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, const value_idx *queryIndices, const value_t *queryData, size_t queryNNZ, int n_query_rows, int n_query_cols, value_idx *output_indices, value_t *output_dists, int k, - cusparseHandle_t cusparseHandle, - std::shared_ptr allocator, - cudaStream_t stream, + raft::handle_t &handle, size_t batch_size_index = 2 << 14, // approx 1M size_t batch_size_query = 2 << 14, raft::distance::DistanceType metric = @@ -443,8 +438,8 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, sparse_knn_t( idxIndptr, idxIndices, idxData, idxNNZ, n_idx_rows, n_idx_cols, queryIndptr, queryIndices, queryData, queryNNZ, n_query_rows, n_query_cols, - output_indices, output_dists, k, cusparseHandle, allocator, stream, - batch_size_index, batch_size_query, metric, metricArg) + output_indices, output_dists, k, handle, batch_size_index, batch_size_query, + metric, metricArg) .run(); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6496ac26c6..0dec4be833 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -66,7 +66,6 @@ add_executable(test_raft test/sparse/csr_transpose.cu test/sparse/degree.cu test/sparse/dist_coo_spmv.cu - test/sparse/dist_csr_spmv.cu test/sparse/distance.cu test/sparse/filter.cu test/sparse/knn.cu @@ -119,6 +118,7 @@ PRIVATE CUDA::cudart CUDA::cusparse rmm::rmm + cuco::cuco FAISS::FAISS GTest::gtest GTest::gtest_main diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index 6e3f3b5038..a83b93f83f 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -30,6 +30,8 @@ #include "../test_utils.h" +#include + namespace raft { namespace sparse { namespace distance { @@ -38,7 +40,7 @@ using namespace raft; using namespace raft::sparse; template -struct SparseDistanceCOOSPMVInputs { +struct InputConfiguration { value_idx n_cols; std::vector indptr_h; @@ -52,46 +54,73 @@ struct SparseDistanceCOOSPMVInputs { float metric_arg = 0.0; }; -template +using dense_smem_strategy_t = dense_smem_strategy; +using hash_strategy_t = hash_strategy; + +template +struct SparseDistanceCOOSPMVInputs { + InputConfiguration input_configuration; + + float capacity_threshold = 0.5; + int map_size = hash_strategy::get_map_size(); +}; + +template ::std::ostream &operator<<( ::std::ostream &os, - const SparseDistanceCOOSPMVInputs &dims) { + const SparseDistanceCOOSPMVInputs &dims) { return os; } -template +template class SparseDistanceCOOSPMVTest : public ::testing::TestWithParam< - SparseDistanceCOOSPMVInputs> { + SparseDistanceCOOSPMVInputs> { public: + SparseDistanceCOOSPMVTest() : dist_config(handle) {} + + template > * = nullptr> + U make_strategy() { + return strategy_t(dist_config, params.capacity_threshold, params.map_size); + } + + template > * = nullptr> + U make_strategy() { + return strategy_t(dist_config); + } + template void compute_dist(reduce_f reduce_func, accum_f accum_func, write_f write_func, bool rev = true) { raft::mr::device::buffer coo_rows( - dist_config.allocator, dist_config.stream, + dist_config.handle.get_device_allocator(), + dist_config.handle.get_stream(), max(dist_config.b_nnz, dist_config.a_nnz)); raft::sparse::convert::csr_to_coo(dist_config.b_indptr, dist_config.b_nrows, coo_rows.data(), dist_config.b_nnz, - dist_config.stream); + dist_config.handle.get_stream()); + strategy_t selected_strategy = make_strategy(); balanced_coo_pairwise_generalized_spmv( out_dists, dist_config, coo_rows.data(), reduce_func, accum_func, - write_func); + write_func, selected_strategy); if (rev) { - raft::sparse::convert::csr_to_coo(dist_config.a_indptr, - dist_config.a_nrows, coo_rows.data(), - dist_config.a_nnz, dist_config.stream); + raft::sparse::convert::csr_to_coo( + dist_config.a_indptr, dist_config.a_nrows, coo_rows.data(), + dist_config.a_nnz, dist_config.handle.get_stream()); balanced_coo_pairwise_generalized_spmv_rev( out_dists, dist_config, coo_rows.data(), reduce_func, accum_func, - write_func); + write_func, selected_strategy); } } void run_spmv() { - switch (params.metric) { + switch (params.input_configuration.metric) { case raft::distance::DistanceType::InnerProduct: compute_dist(Product(), Sum(), AtomicAdd(), true); break; @@ -112,12 +141,13 @@ class SparseDistanceCOOSPMVTest compute_dist(AbsDiff(), Max(), AtomicMax()); break; case raft::distance::DistanceType::LpUnexpanded: { - compute_dist(PDiff(params.metric_arg), Sum(), AtomicAdd()); - float p = 1.0f / params.metric_arg; + compute_dist(PDiff(params.input_configuration.metric_arg), Sum(), + AtomicAdd()); + float p = 1.0f / params.input_configuration.metric_arg; raft::linalg::unaryOp( out_dists, out_dists, dist_config.a_nrows * dist_config.b_nrows, [=] __device__(value_t input) { return powf(input, p); }, - dist_config.stream); + dist_config.handle.get_stream()); } break; default: @@ -127,53 +157,47 @@ class SparseDistanceCOOSPMVTest protected: void make_data() { - std::vector indptr_h = params.indptr_h; - std::vector indices_h = params.indices_h; - std::vector data_h = params.data_h; + std::vector indptr_h = params.input_configuration.indptr_h; + std::vector indices_h = params.input_configuration.indices_h; + std::vector data_h = params.input_configuration.data_h; allocate(indptr, indptr_h.size()); allocate(indices, indices_h.size()); allocate(data, data_h.size()); - update_device(indptr, indptr_h.data(), indptr_h.size(), stream); - update_device(indices, indices_h.data(), indices_h.size(), stream); - update_device(data, data_h.data(), data_h.size(), stream); + update_device(indptr, indptr_h.data(), indptr_h.size(), + handle.get_stream()); + update_device(indices, indices_h.data(), indices_h.size(), + handle.get_stream()); + update_device(data, data_h.data(), data_h.size(), handle.get_stream()); - std::vector out_dists_ref_h = params.out_dists_ref_h; + std::vector out_dists_ref_h = + params.input_configuration.out_dists_ref_h; allocate(out_dists_ref, (indptr_h.size() - 1) * (indptr_h.size() - 1)); update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), - stream); + handle.get_stream()); } void SetUp() override { params = ::testing::TestWithParam< - SparseDistanceCOOSPMVInputs>::GetParam(); - std::shared_ptr alloc( - new raft::mr::device::default_allocator); - CUDA_CHECK(cudaStreamCreate(&stream)); - - CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); - CUSPARSE_CHECK(cusparseSetStream(cusparseHandle, stream)); + SparseDistanceCOOSPMVInputs>::GetParam(); make_data(); - dist_config.b_nrows = params.indptr_h.size() - 1; - dist_config.b_ncols = params.n_cols; - dist_config.b_nnz = params.indices_h.size(); + dist_config.b_nrows = params.input_configuration.indptr_h.size() - 1; + dist_config.b_ncols = params.input_configuration.n_cols; + dist_config.b_nnz = params.input_configuration.indices_h.size(); dist_config.b_indptr = indptr; dist_config.b_indices = indices; dist_config.b_data = data; - dist_config.a_nrows = params.indptr_h.size() - 1; - dist_config.a_ncols = params.n_cols; - dist_config.a_nnz = params.indices_h.size(); + dist_config.a_nrows = params.input_configuration.indptr_h.size() - 1; + dist_config.a_ncols = params.input_configuration.n_cols; + dist_config.a_nnz = params.input_configuration.indices_h.size(); dist_config.a_indptr = indptr; dist_config.a_indices = indices; dist_config.a_data = data; - dist_config.handle = cusparseHandle; - dist_config.allocator = alloc; - dist_config.stream = stream; int out_size = dist_config.a_nrows * dist_config.b_nrows; @@ -181,11 +205,11 @@ class SparseDistanceCOOSPMVTest run_spmv(); - CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); } void TearDown() override { - CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); CUDA_CHECK(cudaFree(indptr)); CUDA_CHECK(cudaFree(indices)); CUDA_CHECK(cudaFree(data)); @@ -194,18 +218,13 @@ class SparseDistanceCOOSPMVTest } void compare() { - raft::print_device_vector("expected: ", out_dists_ref, - params.out_dists_ref_h.size(), std::cout); - raft::print_device_vector("out_dists: ", out_dists, - params.out_dists_ref_h.size(), std::cout); ASSERT_TRUE(devArrMatch(out_dists_ref, out_dists, - params.out_dists_ref_h.size(), + params.input_configuration.out_dists_ref_h.size(), CompareApprox(1e-3))); } protected: - cudaStream_t stream; - cusparseHandle_t cusparseHandle; + raft::handle_t handle; // input data value_idx *indptr, *indices; @@ -216,44 +235,47 @@ class SparseDistanceCOOSPMVTest raft::sparse::distance::distances_config_t dist_config; - SparseDistanceCOOSPMVInputs params; + SparseDistanceCOOSPMVInputs params; }; -const std::vector> inputs_i32_f = { - {2, - {0, 2, 4, 6, 8}, - {0, 1, 0, 1, 0, 1, 0, 1}, - {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, - {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, - 5.0}, - raft::distance::DistanceType::InnerProduct, - 0.0}, - {2, - {0, 2, 4, 6, 8}, - {0, 1, 0, 1, 0, 1, 0, 1}, // indices - {1.0f, 3.0f, 1.0f, 5.0f, 50.0f, 28.0f, 16.0f, 2.0f}, - { - // dense output - 0.0, - 4.0, - 3026.0, - 226.0, - 4.0, - 0.0, - 2930.0, - 234.0, - 3026.0, - 2930.0, - 0.0, - 1832.0, - 226.0, - 234.0, - 1832.0, - 0.0, - }, - raft::distance::DistanceType::L2Unexpanded, - 0.0}, +const InputConfiguration input_inner_product = { + 2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, + {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, + {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, + 5.0}, + raft::distance::DistanceType::InnerProduct, + 0.0}; + +const InputConfiguration input_l2_unexpanded = { + 2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, // indices + {1.0f, 3.0f, 1.0f, 5.0f, 50.0f, 28.0f, 16.0f, 2.0f}, + { + // dense output + 0.0, + 4.0, + 3026.0, + 226.0, + 4.0, + 0.0, + 2930.0, + 234.0, + 3026.0, + 2930.0, + 0.0, + 1832.0, + 226.0, + 234.0, + 1832.0, + 0.0, + }, + raft::distance::DistanceType::L2Unexpanded, + 0.0}; +const InputConfiguration input_canberra = {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, @@ -366,8 +388,9 @@ const std::vector> inputs_i32_f = { 7.0, 0.0}, raft::distance::DistanceType::Canberra, - 0.0}, + 0.0}; +const InputConfiguration input_lp_unexpanded = {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, @@ -480,8 +503,9 @@ const std::vector> inputs_i32_f = { 1.3661374102525012, 0.0}, raft::distance::DistanceType::LpUnexpanded, - 2.0}, + 2.0}; +const InputConfiguration input_linf = {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, @@ -594,40 +618,73 @@ const std::vector> inputs_i32_f = { 0.8429599432532096, 0.0}, raft::distance::DistanceType::Linf, - 0.0}, - - {4, - {0, 1, 1, 2, 4}, - {3, 2, 0, 1}, // indices - {0.99296, 0.42180, 0.11687, 0.305869}, - { - // dense output - 0.0, - 0.99296, - 1.41476, - 1.415707, - 0.99296, - 0.0, - 0.42180, - 0.42274, - 1.41476, - 0.42180, - 0.0, - 0.84454, - 1.41570, - 0.42274, - 0.84454, - 0.0, - }, - raft::distance::DistanceType::L1, - 0.0} - -}; - -typedef SparseDistanceCOOSPMVTest SparseDistanceCOOSPMVTestF; -TEST_P(SparseDistanceCOOSPMVTestF, Result) { compare(); } -INSTANTIATE_TEST_CASE_P(SparseDistanceCOOSPMVTests, SparseDistanceCOOSPMVTestF, - ::testing::ValuesIn(inputs_i32_f)); + 0.0}; + +const InputConfiguration input_l1 = { + 4, + {0, 1, 1, 2, 4}, + {3, 2, 0, 1}, // indices + {0.99296, 0.42180, 0.11687, 0.305869}, + { + // dense output + 0.0, + 0.99296, + 1.41476, + 1.415707, + 0.99296, + 0.0, + 0.42180, + 0.42274, + 1.41476, + 0.42180, + 0.0, + 0.84454, + 1.41570, + 0.42274, + 0.84454, + 0.0, + }, + raft::distance::DistanceType::L1, + 0.0}; + +// test dense smem strategy +const std::vector< + SparseDistanceCOOSPMVInputs> + inputs_dense_strategy = {{input_inner_product}, {input_l2_unexpanded}, + {input_canberra}, {input_lp_unexpanded}, + {input_linf}, {input_l1}}; + +typedef SparseDistanceCOOSPMVTest + SparseDistanceCOOSPMVTestDenseStrategyF; +TEST_P(SparseDistanceCOOSPMVTestDenseStrategyF, Result) { compare(); } +INSTANTIATE_TEST_CASE_P(SparseDistanceCOOSPMVTests, + SparseDistanceCOOSPMVTestDenseStrategyF, + ::testing::ValuesIn(inputs_dense_strategy)); + +// test hash and chunk strategy +const std::vector> + inputs_hash_strategy = {{input_inner_product}, + {input_inner_product, 0.5, 2}, + {input_l2_unexpanded}, + {input_l2_unexpanded, 0.5, 2}, + {input_canberra}, + {input_canberra, 0.5, 2}, + {input_canberra, 0.5, 6}, + {input_lp_unexpanded}, + {input_lp_unexpanded, 0.5, 2}, + {input_lp_unexpanded, 0.5, 6}, + {input_linf}, + {input_linf, 0.5, 2}, + {input_linf, 0.5, 6}, + {input_l1}, + {input_l1, 0.5, 2}}; + +typedef SparseDistanceCOOSPMVTest + SparseDistanceCOOSPMVTestHashStrategyF; +TEST_P(SparseDistanceCOOSPMVTestHashStrategyF, Result) { compare(); } +INSTANTIATE_TEST_CASE_P(SparseDistanceCOOSPMVTests, + SparseDistanceCOOSPMVTestHashStrategyF, + ::testing::ValuesIn(inputs_hash_strategy)); }; // namespace distance }; // end namespace sparse diff --git a/cpp/test/sparse/dist_csr_spmv.cu b/cpp/test/sparse/dist_csr_spmv.cu deleted file mode 100644 index c32748a04e..0000000000 --- a/cpp/test/sparse/dist_csr_spmv.cu +++ /dev/null @@ -1,612 +0,0 @@ -/* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -#include -#include -#include -#include -#include - -#include -#include - -#include "../test_utils.h" - -namespace raft { -namespace sparse { -namespace distance { - -using namespace raft; -using namespace raft::sparse; - -template -struct SparseDistanceCSRSPMVInputs { - value_idx n_cols; - - std::vector indptr_h; - std::vector indices_h; - std::vector data_h; - - std::vector out_dists_ref_h; - - raft::distance::DistanceType metric; - - float metric_arg = 0.0; -}; - -template -::std::ostream &operator<<( - ::std::ostream &os, - const SparseDistanceCSRSPMVInputs &dims) { - return os; -} - -template -class SparseDistanceCSRSPMVTest - : public ::testing::TestWithParam< - SparseDistanceCSRSPMVInputs> { - public: - template - void compute_dist(reduce_f reduce_func, accum_f accum_func) { - generalized_csr_pairwise_semiring( - out_dists, dist_config, reduce_func, accum_func); - } - - void run_spmv() { - switch (params.metric) { - case raft::distance::DistanceType::InnerProduct: - compute_dist(Product(), Sum()); - break; - case raft::distance::DistanceType::L2Unexpanded: - compute_dist(SqDiff(), Sum()); - break; - case raft::distance::DistanceType::Canberra: - compute_dist( - [] __device__(value_t a, value_t b) { - value_t d = fabsf(a) + fabsf(b); - return ((d != 0) * fabsf(a - b)) / (d + (d == 0)); - }, - Sum()); - break; - case raft::distance::DistanceType::L1: - compute_dist(AbsDiff(), Sum()); - break; - case raft::distance::DistanceType::Linf: - compute_dist(AbsDiff(), Max()); - break; - case raft::distance::DistanceType::LpUnexpanded: { - compute_dist(PDiff(params.metric_arg), Sum()); - float pow = 1.0f / params.metric_arg; - raft::linalg::unaryOp( - out_dists, out_dists, dist_config.a_nrows * dist_config.b_nrows, - [=] __device__(value_t input) { return powf(input, pow); }, - dist_config.stream); - - } break; - default: - throw raft::exception("Unknown distance"); - } - } - - protected: - void make_data() { - std::vector indptr_h = params.indptr_h; - std::vector indices_h = params.indices_h; - std::vector data_h = params.data_h; - - allocate(indptr, indptr_h.size()); - allocate(indices, indices_h.size()); - allocate(data, data_h.size()); - - update_device(indptr, indptr_h.data(), indptr_h.size(), stream); - update_device(indices, indices_h.data(), indices_h.size(), stream); - update_device(data, data_h.data(), data_h.size(), stream); - - std::vector out_dists_ref_h = params.out_dists_ref_h; - - allocate(out_dists_ref, (indptr_h.size() - 1) * (indptr_h.size() - 1)); - - update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), - stream); - } - - void SetUp() override { - params = ::testing::TestWithParam< - SparseDistanceCSRSPMVInputs>::GetParam(); - std::shared_ptr alloc( - new raft::mr::device::default_allocator); - CUDA_CHECK(cudaStreamCreate(&stream)); - - CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); - - make_data(); - - dist_config.b_nrows = params.indptr_h.size() - 1; - dist_config.b_ncols = params.n_cols; - dist_config.b_nnz = params.indices_h.size(); - dist_config.b_indptr = indptr; - dist_config.b_indices = indices; - dist_config.b_data = data; - dist_config.a_nrows = params.indptr_h.size() - 1; - dist_config.a_ncols = params.n_cols; - dist_config.a_nnz = params.indices_h.size(); - dist_config.a_indptr = indptr; - dist_config.a_indices = indices; - dist_config.a_data = data; - dist_config.handle = cusparseHandle; - dist_config.allocator = alloc; - dist_config.stream = stream; - - int out_size = dist_config.a_nrows * dist_config.b_nrows; - - allocate(out_dists, out_size); - - run_spmv(); - - CUDA_CHECK(cudaStreamSynchronize(stream)); - } - - void TearDown() override { - CUDA_CHECK(cudaStreamSynchronize(stream)); - CUDA_CHECK(cudaFree(indptr)); - CUDA_CHECK(cudaFree(indices)); - CUDA_CHECK(cudaFree(data)); - CUDA_CHECK(cudaFree(out_dists)); - CUDA_CHECK(cudaFree(out_dists_ref)); - } - - void compare() { - raft::print_device_vector("expected: ", out_dists_ref, - params.out_dists_ref_h.size(), std::cout); - raft::print_device_vector("out_dists: ", out_dists, - params.out_dists_ref_h.size(), std::cout); - ASSERT_TRUE(devArrMatch(out_dists_ref, out_dists, - params.out_dists_ref_h.size(), - CompareApprox(1e-3))); - } - - protected: - cudaStream_t stream; - cusparseHandle_t cusparseHandle; - - // input data - value_idx *indptr, *indices; - value_t *data; - - // output data - value_t *out_dists, *out_dists_ref; - - raft::sparse::distance::distances_config_t dist_config; - - SparseDistanceCSRSPMVInputs params; -}; - -const std::vector> inputs_i32_f = { - {2, - {0, 2, 4, 6, 8}, - {0, 1, 0, 1, 0, 1, 0, 1}, - {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, - {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, - 5.0}, - raft::distance::DistanceType::InnerProduct}, - {2, - {0, 2, 4, 6, 8}, - {0, 1, 0, 1, 0, 1, 0, 1}, // indices - {1.0f, 3.0f, 1.0f, 5.0f, 50.0f, 28.0f, 16.0f, 2.0f}, - { - // dense output - 0.0, - 4.0, - 3026.0, - 226.0, - 4.0, - 0.0, - 2930.0, - 234.0, - 3026.0, - 2930.0, - 0.0, - 1832.0, - 226.0, - 234.0, - 1832.0, - 0.0, - }, - raft::distance::DistanceType::L2Unexpanded, - 0.0}, - - {10, - {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, - {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, - 3, 4, 7, 0, 1, 2, 3, 4, 6, 8, 0, 1, 2, 5, 7, 1, 5, - 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices - {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, - 0.5167, 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, - 0.8206, 0.3625, 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, - 0.0535, 0.2225, 0.8853, 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, - 0.5279, 0.4885, 0.3495, 0.5079, 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, - 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, - {0.0, - 3.3954660629919076, - 5.6469232737388815, - 6.373112846266441, - 4.0212880272531715, - 6.916281504639404, - 5.741508386786526, - 5.411470999663036, - 9.0, - 4.977014354725805, - 3.3954660629919076, - 0.0, - 7.56256082439209, - 5.540261147481582, - 4.832322929216881, - 4.62003193872216, - 6.498056792320361, - 4.309846252268695, - 6.317531174829905, - 6.016362684141827, - 5.6469232737388815, - 7.56256082439209, - 0.0, - 5.974878731322299, - 4.898357301336036, - 6.442097410320605, - 5.227077347287883, - 7.134101195584642, - 5.457753923371659, - 7.0, - 6.373112846266441, - 5.540261147481582, - 5.974878731322299, - 0.0, - 5.5507273748583, - 4.897749658726415, - 9.0, - 8.398776718824767, - 3.908281400328807, - 4.83431066343688, - 4.0212880272531715, - 4.832322929216881, - 4.898357301336036, - 5.5507273748583, - 0.0, - 6.632989819428174, - 7.438852294822894, - 5.6631570310967465, - 7.579428202635459, - 6.760811985364303, - 6.916281504639404, - 4.62003193872216, - 6.442097410320605, - 4.897749658726415, - 6.632989819428174, - 0.0, - 5.249404187382862, - 6.072559523278559, - 4.07661278488929, - 6.19678948003145, - 5.741508386786526, - 6.498056792320361, - 5.227077347287883, - 9.0, - 7.438852294822894, - 5.249404187382862, - 0.0, - 3.854811639654704, - 6.652724827169063, - 5.298236851430971, - 5.411470999663036, - 4.309846252268695, - 7.134101195584642, - 8.398776718824767, - 5.6631570310967465, - 6.072559523278559, - 3.854811639654704, - 0.0, - 7.529184598969917, - 6.903282911791188, - 9.0, - 6.317531174829905, - 5.457753923371659, - 3.908281400328807, - 7.579428202635459, - 4.07661278488929, - 6.652724827169063, - 7.529184598969917, - 0.0, - 7.0, - 4.977014354725805, - 6.016362684141827, - 7.0, - 4.83431066343688, - 6.760811985364303, - 6.19678948003145, - 5.298236851430971, - 6.903282911791188, - 7.0, - 0.0}, - raft::distance::DistanceType::Canberra, - 0.0}, - - {10, - {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, - {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, - 3, 4, 7, 0, 1, 2, 3, 4, 6, 8, 0, 1, 2, 5, 7, 1, 5, - 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices - {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, - 0.5167, 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, - 0.8206, 0.3625, 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, - 0.0535, 0.2225, 0.8853, 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, - 0.5279, 0.4885, 0.3495, 0.5079, 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, - 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, - {0.0, - 1.31462855332296, - 1.3690307816129905, - 1.698603990921237, - 1.3460470789553531, - 1.6636670712582544, - 1.2651744044972217, - 1.1938329352055201, - 1.8811409082590185, - 1.3653115050624267, - 1.31462855332296, - 0.0, - 1.9447722703291133, - 1.42818777206562, - 1.4685491458946494, - 1.3071999866010466, - 1.4988622861692171, - 0.9698559287406783, - 1.4972023224597841, - 1.5243383567266802, - 1.3690307816129905, - 1.9447722703291133, - 0.0, - 1.2748400840107568, - 1.0599569946448246, - 1.546591282841402, - 1.147526531928459, - 1.447002179128145, - 1.5982242387673176, - 1.3112533607072414, - 1.698603990921237, - 1.42818777206562, - 1.2748400840107568, - 0.0, - 1.038121552545461, - 1.011788365364402, - 1.3907391109256988, - 1.3128200942311496, - 1.19595706584447, - 1.3233328139624725, - 1.3460470789553531, - 1.4685491458946494, - 1.0599569946448246, - 1.038121552545461, - 0.0, - 1.3642741698145529, - 1.3493868683808095, - 1.394942694628328, - 1.572881849642552, - 1.380122665319464, - 1.6636670712582544, - 1.3071999866010466, - 1.546591282841402, - 1.011788365364402, - 1.3642741698145529, - 0.0, - 1.018961640373018, - 1.0114394258945634, - 0.8338711034820684, - 1.1247823842299223, - 1.2651744044972217, - 1.4988622861692171, - 1.147526531928459, - 1.3907391109256988, - 1.3493868683808095, - 1.018961640373018, - 0.0, - 0.7701238110357329, - 1.245486437864406, - 0.5551259549534626, - 1.1938329352055201, - 0.9698559287406783, - 1.447002179128145, - 1.3128200942311496, - 1.394942694628328, - 1.0114394258945634, - 0.7701238110357329, - 0.0, - 1.1886800117391216, - 1.0083692448135637, - 1.8811409082590185, - 1.4972023224597841, - 1.5982242387673176, - 1.19595706584447, - 1.572881849642552, - 0.8338711034820684, - 1.245486437864406, - 1.1886800117391216, - 0.0, - 1.3661374102525012, - 1.3653115050624267, - 1.5243383567266802, - 1.3112533607072414, - 1.3233328139624725, - 1.380122665319464, - 1.1247823842299223, - 0.5551259549534626, - 1.0083692448135637, - 1.3661374102525012, - 0.0}, - raft::distance::DistanceType::LpUnexpanded, - 2.0}, - - {10, - {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, - {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, - 3, 4, 7, 0, 1, 2, 3, 4, 6, 8, 0, 1, 2, 5, 7, 1, 5, - 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices - {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, - 0.5167, 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, - 0.8206, 0.3625, 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, - 0.0535, 0.2225, 0.8853, 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, - 0.5279, 0.4885, 0.3495, 0.5079, 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, - 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, - {0.0, - 0.9251771844789913, - 0.9036452083899731, - 0.9251771844789913, - 0.8706483735804971, - 0.9251771844789913, - 0.717493881903289, - 0.6920214832303888, - 0.9251771844789913, - 0.9251771844789913, - 0.9251771844789913, - 0.0, - 0.9036452083899731, - 0.8655339692155823, - 0.8706483735804971, - 0.8655339692155823, - 0.8655339692155823, - 0.6329837991017668, - 0.8655339692155823, - 0.8655339692155823, - 0.9036452083899731, - 0.9036452083899731, - 0.0, - 0.7988276152181608, - 0.7028075145996631, - 0.9036452083899731, - 0.9036452083899731, - 0.9036452083899731, - 0.8429599432532096, - 0.9036452083899731, - 0.9251771844789913, - 0.8655339692155823, - 0.7988276152181608, - 0.0, - 0.48376552205293305, - 0.8206394616536681, - 0.8206394616536681, - 0.8206394616536681, - 0.8429599432532096, - 0.8206394616536681, - 0.8706483735804971, - 0.8706483735804971, - 0.7028075145996631, - 0.48376552205293305, - 0.0, - 0.8706483735804971, - 0.8706483735804971, - 0.8706483735804971, - 0.8429599432532096, - 0.8706483735804971, - 0.9251771844789913, - 0.8655339692155823, - 0.9036452083899731, - 0.8206394616536681, - 0.8706483735804971, - 0.0, - 0.8853924473642432, - 0.535821510936138, - 0.6497196601457607, - 0.8853924473642432, - 0.717493881903289, - 0.8655339692155823, - 0.9036452083899731, - 0.8206394616536681, - 0.8706483735804971, - 0.8853924473642432, - 0.0, - 0.5279604218147174, - 0.6658348373853169, - 0.33799874888632914, - 0.6920214832303888, - 0.6329837991017668, - 0.9036452083899731, - 0.8206394616536681, - 0.8706483735804971, - 0.535821510936138, - 0.5279604218147174, - 0.0, - 0.662579808115858, - 0.5079750812968089, - 0.9251771844789913, - 0.8655339692155823, - 0.8429599432532096, - 0.8429599432532096, - 0.8429599432532096, - 0.6497196601457607, - 0.6658348373853169, - 0.662579808115858, - 0.0, - 0.8429599432532096, - 0.9251771844789913, - 0.8655339692155823, - 0.9036452083899731, - 0.8206394616536681, - 0.8706483735804971, - 0.8853924473642432, - 0.33799874888632914, - 0.5079750812968089, - 0.8429599432532096, - 0.0}, - raft::distance::DistanceType::Linf, - 0.0}, - - {4, - {0, 1, 1, 2, 4}, - {3, 2, 0, 1}, // indices - {0.99296, 0.42180, 0.11687, 0.305869}, - { - // dense output - 0.0, - 0.99296, - 1.41476, - 1.415707, - 0.99296, - 0.0, - 0.42180, - 0.42274, - 1.41476, - 0.42180, - 0.0, - 0.84454, - 1.41570, - 0.42274, - 0.84454, - 0.0, - }, - raft::distance::DistanceType::L1, - 0.0} - -}; - -typedef SparseDistanceCSRSPMVTest SparseDistanceCSRSPMVTestF; -TEST_P(SparseDistanceCSRSPMVTestF, Result) { compare(); } -INSTANTIATE_TEST_CASE_P(SparseDistanceCSRSPMVTests, SparseDistanceCSRSPMVTestF, - ::testing::ValuesIn(inputs_i32_f)); - -}; // namespace distance -}; // end namespace sparse -}; // end namespace raft diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 4247e374d6..9c2f9a4e27 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -58,41 +58,15 @@ template template class SparseDistanceTest : public ::testing::TestWithParam> { - protected: - void make_data() { - std::vector indptr_h = params.indptr_h; - std::vector indices_h = params.indices_h; - std::vector data_h = params.data_h; - - allocate(indptr, indptr_h.size()); - allocate(indices, indices_h.size()); - allocate(data, data_h.size()); - - update_device(indptr, indptr_h.data(), indptr_h.size(), stream); - update_device(indices, indices_h.data(), indices_h.size(), stream); - update_device(data, data_h.data(), data_h.size(), stream); - - std::vector out_dists_ref_h = params.out_dists_ref_h; - - allocate(out_dists_ref, (indptr_h.size() - 1) * (indptr_h.size() - 1)); - - update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), - stream); - } + public: + SparseDistanceTest() : dist_config(handle) {} void SetUp() override { params = ::testing::TestWithParam< SparseDistanceInputs>::GetParam(); - std::shared_ptr alloc( - new raft::mr::device::default_allocator); - CUDA_CHECK(cudaStreamCreate(&stream)); - - CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); - CUSPARSE_CHECK(cusparseSetStream(cusparseHandle, stream)); make_data(); - raft::sparse::distance::distances_config_t dist_config; dist_config.b_nrows = params.indptr_h.size() - 1; dist_config.b_ncols = params.n_cols; dist_config.b_nnz = params.indices_h.size(); @@ -105,9 +79,6 @@ class SparseDistanceTest dist_config.a_indptr = indptr; dist_config.a_indices = indices; dist_config.a_data = data; - dist_config.handle = cusparseHandle; - dist_config.allocator = alloc; - dist_config.stream = stream; int out_size = dist_config.a_nrows * dist_config.b_nrows; @@ -115,11 +86,11 @@ class SparseDistanceTest pairwiseDistance(out_dists, dist_config, params.metric, params.metric_arg); - CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); } void TearDown() override { - CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); CUDA_CHECK(cudaFree(indptr)); CUDA_CHECK(cudaFree(indices)); CUDA_CHECK(cudaFree(data)); @@ -134,8 +105,30 @@ class SparseDistanceTest } protected: - cudaStream_t stream; - cusparseHandle_t cusparseHandle; + void make_data() { + std::vector indptr_h = params.indptr_h; + std::vector indices_h = params.indices_h; + std::vector data_h = params.data_h; + + allocate(indptr, indptr_h.size()); + allocate(indices, indices_h.size()); + allocate(data, data_h.size()); + + update_device(indptr, indptr_h.data(), indptr_h.size(), + handle.get_stream()); + update_device(indices, indices_h.data(), indices_h.size(), + handle.get_stream()); + update_device(data, data_h.data(), data_h.size(), handle.get_stream()); + + std::vector out_dists_ref_h = params.out_dists_ref_h; + + allocate(out_dists_ref, (indptr_h.size() - 1) * (indptr_h.size() - 1)); + + update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), + dist_config.handle.get_stream()); + } + + raft::handle_t handle; // input data value_idx *indptr, *indices; @@ -145,6 +138,7 @@ class SparseDistanceTest value_t *out_dists, *out_dists_ref; SparseDistanceInputs params; + raft::sparse::distance::distances_config_t dist_config; }; const std::vector> inputs_i32_f = { @@ -773,7 +767,8 @@ const std::vector> inputs_i32_f = { 0.0, }, raft::distance::DistanceType::L1, - 0.0}}; + 0.0}, +}; typedef SparseDistanceTest SparseDistanceTestF; TEST_P(SparseDistanceTestF, Result) { compare(); } diff --git a/cpp/test/sparse/knn.cu b/cpp/test/sparse/knn.cu index 4759eebe4b..8c3bf36318 100644 --- a/cpp/test/sparse/knn.cu +++ b/cpp/test/sparse/knn.cu @@ -63,43 +63,10 @@ template template class SparseKNNTest : public ::testing::TestWithParam> { - protected: - void make_data() { - std::vector indptr_h = params.indptr_h; - std::vector indices_h = params.indices_h; - std::vector data_h = params.data_h; - - allocate(indptr, indptr_h.size()); - allocate(indices, indices_h.size()); - allocate(data, data_h.size()); - - update_device(indptr, indptr_h.data(), indptr_h.size(), stream); - update_device(indices, indices_h.data(), indices_h.size(), stream); - update_device(data, data_h.data(), data_h.size(), stream); - - std::vector out_dists_ref_h = params.out_dists_ref_h; - std::vector out_indices_ref_h = params.out_indices_ref_h; - - allocate(out_indices_ref, out_indices_ref_h.size()); - allocate(out_dists_ref, out_dists_ref_h.size()); - - update_device(out_indices_ref, out_indices_ref_h.data(), - out_indices_ref_h.size(), stream); - update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), - stream); - - allocate(out_dists, n_rows * k); - allocate(out_indices, n_rows * k); - } - + public: void SetUp() override { params = ::testing::TestWithParam>::GetParam(); - std::shared_ptr alloc( - new raft::mr::device::default_allocator); - CUDA_CHECK(cudaStreamCreate(&stream)); - - CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); n_rows = params.indptr_h.size() - 1; nnz = params.indices_h.size(); @@ -109,16 +76,13 @@ class SparseKNNTest raft::sparse::selection::brute_force_knn( indptr, indices, data, nnz, n_rows, params.n_cols, indptr, indices, data, - nnz, n_rows, params.n_cols, out_indices, out_dists, k, cusparseHandle, - alloc, stream, params.batch_size_index, params.batch_size_query, - params.metric); + nnz, n_rows, params.n_cols, out_indices, out_dists, k, handle, + params.batch_size_index, params.batch_size_query, params.metric); - CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); } void TearDown() override { - CUDA_CHECK(cudaStreamSynchronize(stream)); - CUDA_CHECK(cudaFree(indptr)); CUDA_CHECK(cudaFree(indices)); CUDA_CHECK(cudaFree(data)); @@ -126,8 +90,6 @@ class SparseKNNTest CUDA_CHECK(cudaFree(out_dists)); CUDA_CHECK(cudaFree(out_indices_ref)); CUDA_CHECK(cudaFree(out_dists_ref)); - - CUDA_CHECK(cudaStreamDestroy(stream)); } void compare() { @@ -138,8 +100,37 @@ class SparseKNNTest } protected: - cudaStream_t stream; - cusparseHandle_t cusparseHandle; + void make_data() { + std::vector indptr_h = params.indptr_h; + std::vector indices_h = params.indices_h; + std::vector data_h = params.data_h; + + allocate(indptr, indptr_h.size()); + allocate(indices, indices_h.size()); + allocate(data, data_h.size()); + + update_device(indptr, indptr_h.data(), indptr_h.size(), + handle.get_stream()); + update_device(indices, indices_h.data(), indices_h.size(), + handle.get_stream()); + update_device(data, data_h.data(), data_h.size(), handle.get_stream()); + + std::vector out_dists_ref_h = params.out_dists_ref_h; + std::vector out_indices_ref_h = params.out_indices_ref_h; + + allocate(out_indices_ref, out_indices_ref_h.size()); + allocate(out_dists_ref, out_dists_ref_h.size()); + + update_device(out_indices_ref, out_indices_ref_h.data(), + out_indices_ref_h.size(), handle.get_stream()); + update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), + handle.get_stream()); + + allocate(out_dists, n_rows * k); + allocate(out_indices, n_rows * k); + } + + raft::handle_t handle; int n_rows, nnz, k;