From f25907b6405db2e9880fd6e5834ae6096e793ed6 Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Wed, 26 Jul 2023 15:32:54 +0200 Subject: [PATCH] Fix sparse KNN for large batches (#1640) Answers https://github.com/rapidsai/raft/issues/1187 Authors: - Victor Lafargue (https://github.com/viclafargue) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1640 --- cpp/bench/ann/src/faiss/faiss_benchmark.cu | 2 +- cpp/include/raft/sparse/detail/utils.h | 5 +++-- .../raft/sparse/distance/detail/coo_spmv.cuh | 12 ++++-------- .../raft/sparse/distance/detail/lp_distance.cuh | 16 +++++++++++----- cpp/include/raft/sparse/neighbors/detail/knn.cuh | 3 ++- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/cpp/bench/ann/src/faiss/faiss_benchmark.cu b/cpp/bench/ann/src/faiss/faiss_benchmark.cu index 0aa4e76103..0bad86905b 100644 --- a/cpp/bench/ann/src/faiss/faiss_benchmark.cu +++ b/cpp/bench/ann/src/faiss/faiss_benchmark.cu @@ -104,10 +104,10 @@ std::unique_ptr> create_algo(const std::string& algo, // stop compiler warning; not all algorithms support multi-GPU so it may not be used (void)dev_list; - raft::bench::ann::Metric metric = parse_metric(distance); std::unique_ptr> ann; if constexpr (std::is_same_v) { + raft::bench::ann::Metric metric = parse_metric(distance); if (algo == "faiss_gpu_ivf_flat") { ann = make_algo(metric, dim, conf, dev_list); } else if (algo == "faiss_gpu_ivf_pq") { diff --git a/cpp/include/raft/sparse/detail/utils.h b/cpp/include/raft/sparse/detail/utils.h index 56e8832e0a..b5017451e6 100644 --- a/cpp/include/raft/sparse/detail/utils.h +++ b/cpp/include/raft/sparse/detail/utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,7 +90,8 @@ __global__ void iota_fill_block_kernel(value_idx* indices, value_idx ncols) int tid = threadIdx.x; for (int i = tid; i < ncols; i += blockDim.x) { - indices[row * ncols + i] = i; + uint64_t idx = (uint64_t)row * (uint64_t)ncols; + indices[idx + i] = i; } } diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index 9c233ecc19..c0d5fbc365 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -56,10 +56,8 @@ inline void balanced_coo_pairwise_generalized_spmv( strategy_t strategy, int chunk_size = 500000) { - RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, - 0, - sizeof(value_t) * config_.a_nrows * config_.b_nrows, - resource::get_cuda_stream(config_.handle))); + uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; + RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, resource::get_cuda_stream(config_.handle))); strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); }; @@ -112,10 +110,8 @@ inline void balanced_coo_pairwise_generalized_spmv( write_f write_func, int chunk_size = 500000) { - RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, - 0, - sizeof(value_t) * config_.a_nrows * config_.b_nrows, - resource::get_cuda_stream(config_.handle))); + uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; + RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, resource::get_cuda_stream(config_.handle))); int max_cols = max_cols_per_block(); diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 5ee2cd7b15..ff9534a157 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -126,11 +126,13 @@ class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_t::compute(out_dists); + + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; // Sqrt Post-processing raft::linalg::unaryOp( out_dists, out_dists, - this->config_->a_nrows * this->config_->b_nrows, + n, [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; return raft::sqrt(abs(input) * neg); @@ -203,10 +205,11 @@ class lp_unexpanded_distances_t : public distances_t { raft::add_op(), raft::atomic_add_op()); + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; value_t one_over_p = value_t{1} / p; raft::linalg::unaryOp(out_dists, out_dists, - config_->a_nrows * config_->b_nrows, + n, raft::pow_const_op(one_over_p), resource::get_cuda_stream(config_->handle)); } @@ -229,10 +232,11 @@ class hamming_unexpanded_distances_t : public distances_t { unexpanded_lp_distances( out_dists, config_, raft::notequal_op(), raft::add_op(), raft::atomic_add_op()); + uint64_t n = (uint64_t)config_->a_nrows * (uint64_t)config_->b_nrows; value_t n_cols = 1.0 / config_->a_ncols; raft::linalg::unaryOp(out_dists, out_dists, - config_->a_nrows * config_->b_nrows, + n, raft::mul_const_op(n_cols), resource::get_cuda_stream(config_->handle)); } @@ -271,10 +275,11 @@ class jensen_shannon_unexpanded_distances_t : public distances_t { raft::add_op(), raft::atomic_add_op()); + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; raft::linalg::unaryOp( out_dists, out_dists, - config_->a_nrows * config_->b_nrows, + n, [=] __device__(value_t input) { return raft::sqrt(0.5 * input); }, resource::get_cuda_stream(config_->handle)); } @@ -311,9 +316,10 @@ class kl_divergence_unexpanded_distances_t : public distances_t { raft::add_op(), raft::atomic_add_op()); + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; raft::linalg::unaryOp(out_dists, out_dists, - config_->a_nrows * config_->b_nrows, + n, raft::mul_const_op(0.5), resource::get_cuda_stream(config_->handle)); } diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index cfb1a6403b..f2be427367 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -231,7 +231,8 @@ class sparse_knn_t { /** * Compute distances */ - size_t dense_size = idx_batcher.batch_rows() * query_batcher.batch_rows(); + uint64_t dense_size = + (uint64_t)idx_batcher.batch_rows() * (uint64_t)query_batcher.batch_rows(); rmm::device_uvector batch_dists(dense_size, resource::get_cuda_stream(handle)); RAFT_CUDA_TRY(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t)));