Fix sparse KNN for large batches #1640

Merged
merged 3 commits on Jul 26, 2023
Changes from 1 commit
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_benchmark.cu
@@ -104,10 +104,10 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
   // stop compiler warning; not all algorithms support multi-GPU so it may not be used
   (void)dev_list;

-  raft::bench::ann::Metric metric = parse_metric(distance);
   std::unique_ptr<raft::bench::ann::ANN<T>> ann;

   if constexpr (std::is_same_v<T, float>) {
+    raft::bench::ann::Metric metric = parse_metric(distance);
     if (algo == "faiss_gpu_ivf_flat") {
       ann = make_algo<T, raft::bench::ann::FaissGpuIVFFlat>(metric, dim, conf, dev_list);
     } else if (algo == "faiss_gpu_ivf_pq") {
3 changes: 2 additions & 1 deletion cpp/include/raft/sparse/detail/utils.h
@@ -90,7 +90,8 @@ __global__ void iota_fill_block_kernel(value_idx* indices, value_idx ncols)
   int tid = threadIdx.x;

   for (int i = tid; i < ncols; i += blockDim.x) {
-    indices[row * ncols + i] = i;
+    uint64_t idx = (uint64_t)row * (uint64_t)ncols;
+    indices[idx + i] = i;
   }
 }
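The kernel indexes the output with `row * ncols + i`; when `value_idx` is a 32-bit type the product silently wraps once a batch reaches roughly 2^31 elements, which is the large-batch case this PR targets, so the patch forms the row offset in `uint64_t` before adding the column index. A standalone sketch of the difference, not part of the patch, with hypothetical sizes and unsigned types so the wraparound is well defined and printable rather than undefined behavior:

```cpp
// Sketch only: shows the wraparound the widened index arithmetic avoids.
#include <cstdint>
#include <cstdio>

int main()
{
  uint32_t row   = 70000;  // hypothetical row index in a large batch
  uint32_t ncols = 70000;  // hypothetical number of columns

  uint32_t wrapped = row * ncols;                      // 4.9e9 wraps modulo 2^32
  uint64_t widened = (uint64_t)row * (uint64_t)ncols;  // widen *before* multiplying

  printf("32-bit product: %u\n", wrapped);                          // 605032704 -- wrong
  printf("64-bit product: %llu\n", (unsigned long long)widened);    // 4900000000
  return 0;
}
```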

6 changes: 4 additions & 2 deletions cpp/include/raft/sparse/distance/detail/coo_spmv.cuh
@@ -56,9 +56,10 @@ inline void balanced_coo_pairwise_generalized_spmv(
   strategy_t strategy,
   int chunk_size = 500000)
 {
+  uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows;
   RAFT_CUDA_TRY(cudaMemsetAsync(out_dists,
                                 0,
-                                sizeof(value_t) * config_.a_nrows * config_.b_nrows,
+                                n,
                                 resource::get_cuda_stream(config_.handle)));

   strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size);
@@ -112,9 +113,10 @@ write_f write_func,
   write_f write_func,
   int chunk_size = 500000)
 {
+  uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows;
   RAFT_CUDA_TRY(cudaMemsetAsync(out_dists,
                                 0,
-                                sizeof(value_t) * config_.a_nrows * config_.b_nrows,
+                                n,
                                 resource::get_cuda_stream(config_.handle)));

   int max_cols = max_cols_per_block<value_idx, value_t>();
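Both overloads now compute the `cudaMemsetAsync` byte count as an explicit `uint64_t` product, so the full `a_nrows * b_nrows` distance matrix is zeroed even when it holds more than 2^31 elements. Below is a host-side sketch of the same pattern, not RAFT code, with deliberately small hypothetical dimensions so it runs on any GPU; the point is only that the byte count is formed in 64-bit arithmetic before the call:

```cpp
// Sketch: form the memset byte count in 64-bit arithmetic, then pass it to
// cudaMemsetAsync (which takes a size_t). In the PR the real products exceed 2^31.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int main()
{
  int a_nrows = 1000, b_nrows = 1000;  // hypothetical batch sizes
  uint64_t n_bytes = (uint64_t)sizeof(float) * (uint64_t)a_nrows * (uint64_t)b_nrows;

  float* out_dists = nullptr;
  if (cudaMalloc(&out_dists, n_bytes) != cudaSuccess) {
    printf("cudaMalloc failed\n");
    return 1;
  }

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemsetAsync(out_dists, 0, n_bytes, stream);  // zero the full distance matrix
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(out_dists);
  printf("zeroed %llu bytes\n", (unsigned long long)n_bytes);
  return 0;
}
```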
16 changes: 11 additions & 5 deletions cpp/include/raft/sparse/distance/detail/lp_distance.cuh
@@ -126,11 +126,13 @@ class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_t<value_id
   void compute(value_t* out_dists)
   {
     l2_unexpanded_distances_t<value_idx, value_t>::compute(out_dists);
+
+    uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
     // Sqrt Post-processing
     raft::linalg::unaryOp<value_t>(
       out_dists,
       out_dists,
-      this->config_->a_nrows * this->config_->b_nrows,
+      n,
       [] __device__(value_t input) {
         int neg = input < 0 ? -1 : 1;
         return raft::sqrt(abs(input) * neg);
@@ -203,10 +205,11 @@ class lp_unexpanded_distances_t : public distances_t<value_t> {
       raft::add_op(),
       raft::atomic_add_op());

+    uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
     value_t one_over_p = value_t{1} / p;
     raft::linalg::unaryOp<value_t>(out_dists,
                                    out_dists,
-                                   config_->a_nrows * config_->b_nrows,
+                                   n,
                                    raft::pow_const_op<value_t>(one_over_p),
                                    resource::get_cuda_stream(config_->handle));
   }
@@ -229,10 +232,11 @@ class hamming_unexpanded_distances_t : public distances_t<value_t> {
     unexpanded_lp_distances<value_idx, value_t>(
       out_dists, config_, raft::notequal_op(), raft::add_op(), raft::atomic_add_op());

+    uint64_t n = (uint64_t)config_->a_nrows * (uint64_t)config_->b_nrows;
     value_t n_cols = 1.0 / config_->a_ncols;
     raft::linalg::unaryOp<value_t>(out_dists,
                                    out_dists,
-                                   config_->a_nrows * config_->b_nrows,
+                                   n,
                                    raft::mul_const_op<value_t>(n_cols),
                                    resource::get_cuda_stream(config_->handle));
   }
@@ -271,10 +275,11 @@ class jensen_shannon_unexpanded_distances_t : public distances_t<value_t> {
       raft::add_op(),
       raft::atomic_add_op());

+    uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
     raft::linalg::unaryOp<value_t>(
       out_dists,
       out_dists,
-      config_->a_nrows * config_->b_nrows,
+      n,
       [=] __device__(value_t input) { return raft::sqrt(0.5 * input); },
       resource::get_cuda_stream(config_->handle));
   }
@@ -311,9 +316,10 @@ class kl_divergence_unexpanded_distances_t : public distances_t<value_t> {
       raft::add_op(),
       raft::atomic_add_op());

+    uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
     raft::linalg::unaryOp<value_t>(out_dists,
                                    out_dists,
-                                   config_->a_nrows * config_->b_nrows,
+                                   n,
                                    raft::mul_const_op<value_t>(0.5),
                                    resource::get_cuda_stream(config_->handle));
   }
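Each of these post-processing steps now precomputes the element count `n` as a `uint64_t` instead of passing the raw `a_nrows * b_nrows` product (an int-by-int multiplication) straight to `raft::linalg::unaryOp`. Below is a toy stand-in for such an elementwise transform, not RAFT's actual API, showing why the count has to be widened before the call rather than inside it:

```cpp
// Toy stand-in (not raft::linalg::unaryOp) for an elementwise transform whose index
// type is deduced from the count argument. If the caller passes an int-by-int product,
// the count has already overflowed before the function is entered; precomputing a
// uint64_t count, as the patch does, keeps the full range.
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T, typename IdxT, typename Op>
void unary_apply(T* out, const T* in, IdxT count, Op op)
{
  for (IdxT i = 0; i < count; ++i) { out[i] = op(in[i]); }
}

int main()
{
  int a_nrows = 60000, b_nrows = 60000;  // hypothetical batch sizes

  // Wrong: 60000 * 60000 = 3.6e9 overflows int before unary_apply sees it,
  // and IdxT would be deduced as int.
  // int bad_count = a_nrows * b_nrows;   // undefined behavior, do not do this

  // Right: widen first; IdxT is deduced as uint64_t and covers every element.
  uint64_t n = (uint64_t)a_nrows * (uint64_t)b_nrows;
  printf("element count: %llu\n", (unsigned long long)n);  // 3600000000

  // Tiny demonstration on a small buffer (the arithmetic above is the point).
  std::vector<float> v{1.0f, 4.0f, 9.0f};
  unary_apply(v.data(), v.data(), (uint64_t)v.size(), [](float x) { return x * 0.5f; });
  printf("%g %g %g\n", v[0], v[1], v[2]);  // 0.5 2 4.5
  return 0;
}
```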
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/neighbors/detail/knn.cuh
@@ -231,7 +231,7 @@ class sparse_knn_t {
     /**
      * Compute distances
      */
-    size_t dense_size = idx_batcher.batch_rows() * query_batcher.batch_rows();
+    uint64_t dense_size = (uint64_t)idx_batcher.batch_rows() * (uint64_t)query_batcher.batch_rows();
     rmm::device_uvector<value_t> batch_dists(dense_size, resource::get_cuda_stream(handle));

     RAFT_CUDA_TRY(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t)));
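The old line assigned the product to `size_t`, but if `batch_rows()` returns a 32-bit index type the multiplication itself is still carried out in 32 bits, so the overflow happens before the widening assignment; the fix casts each operand first. A short sketch of that distinction (the `int` return type here is an assumption for illustration):

```cpp
// Sketch (hypothetical sizes): assigning an int-by-int product to a wide type is not
// enough, because the multiplication overflows before the conversion takes place.
#include <cstdint>
#include <cstdio>

int main()
{
  // Pretend batch_rows() returns int for both batchers.
  int idx_batch_rows   = 50000;
  int query_batch_rows = 50000;

  // Old pattern: the product is computed in 32-bit int (undefined behavior once it
  // exceeds INT_MAX) and only then converted, so dense_size is already wrong.
  // size_t dense_size = idx_batch_rows * query_batch_rows;

  // Fixed pattern: widen each operand, multiply in 64-bit, then size the buffer.
  uint64_t dense_size = (uint64_t)idx_batch_rows * (uint64_t)query_batch_rows;
  printf("dense_size = %llu elements\n", (unsigned long long)dense_size);  // 2500000000
  return 0;
}
```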