From 2b467b8e9fb9a5cc76a16e76f5260174b6b8004e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 14 Feb 2022 14:00:54 -0500 Subject: [PATCH 01/12] Initializing more memory to fix stability of results --- cpp/include/raft/spatial/knn/detail/ball_cover.cuh | 13 +++++++++++++ .../spatial/knn/detail/ball_cover/registers.cuh | 1 + 2 files changed, 14 insertions(+) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 81eee717d6..9e673fd5c6 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -76,6 +76,9 @@ void sample_landmarks(const raft::handle_t& handle, thrust::fill( handle.get_thrust_policy(), R_1nn_ones.data(), R_1nn_ones.data() + R_1nn_ones.size(), 1.0); + thrust::fill( + handle.get_thrust_policy(), R_indices.data(), R_indices.data() + R_indices.size(), 0.0); + /** * 1. Randomly sample sqrt(n) points from X */ @@ -234,6 +237,16 @@ void perform_rbc_query(const raft::handle_t& handle, float weight = 1.0, bool perform_post_filtering = true) { + // initialize output inds and dists + thrust::fill(handle.get_thrust_policy(), + inds, + inds + (k * n_query_pts), + std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + dists, + dists + (k * n_query_pts), + std::numeric_limits::max()); + // Compute nearest k for each neighborhood in each closest R rbc_low_dim_pass_one(handle, index, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index a06cfd09de..6461966244 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -610,6 +610,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, const value_int bitset_size = ceil(index.n_landmarks / 32.0); rmm::device_uvector bitset(bitset_size * index.m, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); perform_post_filter_registers <<>>( From 10f31dc9bcad4c0b9180b8b1c029ac033781f007 Mon Sep 17 00:00:00 2001 From: Vinay D Date: Tue, 15 Feb 2022 11:20:58 +0530 Subject: [PATCH 02/12] Initializing device_uvector to fix issues with Random Ball Cover --- .../raft/spatial/knn/detail/ball_cover.cuh | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 9e673fd5c6..e1e2a04ee4 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -302,6 +302,16 @@ void rbc_build_index(const raft::handle_t& handle, rmm::device_uvector R_knn_inds(index.m, handle.get_stream()); rmm::device_uvector R_knn_dists(index.m, handle.get_stream()); + // Initialize the uvectors + thrust::fill(handle.get_thrust_policy(), + R_knn_inds.begin(), + R_knn_inds.end(), + std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + R_knn_dists.begin(), + R_knn_dists.end(), + std::numeric_limits::max()); + /** * 1. Randomly sample sqrt(n) points from X */ @@ -353,6 +363,16 @@ void rbc_all_knn_query(const raft::handle_t& handle, rmm::device_uvector R_knn_inds(k * index.m, handle.get_stream()); rmm::device_uvector R_knn_dists(k * index.m, handle.get_stream()); + // Initialize the uvectors + thrust::fill(handle.get_thrust_policy(), + R_knn_inds.begin(), + R_knn_inds.end(), + std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + R_knn_dists.begin(), + R_knn_dists.end(), + std::numeric_limits::max()); + // For debugging / verification. Remove before releasing rmm::device_uvector dists_counter(index.m, handle.get_stream()); rmm::device_uvector post_dists_counter(index.m, handle.get_stream()); From db4b8d77e01c5bf33e9e52bb3499166cd7bb0e2a Mon Sep 17 00:00:00 2001 From: Vinay D Date: Tue, 15 Feb 2022 11:24:36 +0530 Subject: [PATCH 03/12] Initializing few more device_uvectors --- cpp/include/raft/spatial/knn/detail/ball_cover.cuh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index e1e2a04ee4..2b245d06cb 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -429,6 +429,16 @@ void rbc_knn_query(const raft::handle_t& handle, rmm::device_uvector R_knn_inds(k * index.m, handle.get_stream()); rmm::device_uvector R_knn_dists(k * index.m, handle.get_stream()); + // Initialize the uvectors + thrust::fill(handle.get_thrust_policy(), + R_knn_inds.begin(), + R_knn_inds.end(), + std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + R_knn_dists.begin(), + R_knn_dists.end(), + std::numeric_limits::max()); + k_closest_landmarks(handle, index, query, n_query_pts, k, R_knn_inds.data(), R_knn_dists.data()); // For debugging / verification. Remove before releasing From 016967ecb51551d19ecba1f1bb5174ab6613bf36 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 16:01:43 -0500 Subject: [PATCH 04/12] Fixing mismatching bug. Still need to figure out why heap.warpTopKRDist is max int after adding a bunch of items. --- .../knn/detail/ball_cover/registers.cuh | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 6461966244..81f71e37dc 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -96,6 +96,7 @@ __global__ void perform_post_filter_registers(const value_t* X, // zero out bits for closest k landmarks for (value_int j = threadIdx.x; j < k; j += tpb) { + int la = (int)R_knn_inds[blockIdx.x * k + j]; _zero_bit(shared_mem, (std::uint32_t)R_knn_inds[blockIdx.x * k + j]); } @@ -228,12 +229,14 @@ __global__ void compute_final_dists_registers(const value_t* X_index, for (; i < limit; i += tpb) { value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * + + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * abs(heap.warpKTopRDist - cur_candidate_dist) - heap.warpKTop * cur_candidate_dist) / heap.warpKTopRDist; - z = isnan(z) ? 0.0 : z; + z = isnan(z) || isinf(z) ? 0.0 : z; + // If lower bound on distance could possibly be in // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); @@ -261,7 +264,8 @@ __global__ void compute_final_dists_registers(const value_t* X_index, heap.warpKTop * cur_candidate_dist) / heap.warpKTopRDist; - z = isnan(z) ? 0.0 : z; + z = isnan(z) || isinf(z) ? 0.0 : z; + // If lower bound on distance could possibly be in // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); @@ -361,8 +365,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, shared_memV, k); - value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; - + value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; value_int n_dists_computed = 0; /** @@ -409,9 +412,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, heap.warpKTop * cur_candidate_dist) / heap.warpKTopRDist; - z = isnan(z) ? 0.0 : z; + z = isnan(z) || isinf(z) ? 0.0 : z; value_t dist = std::numeric_limits::max(); - if (i < k || z <= heap.warpKTop) { + + if (z <= heap.warpKTop) { const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { @@ -433,9 +437,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, heap.warpKTop * cur_candidate_dist) / heap.warpKTopRDist; - z = isnan(z) ? 0.0 : z; + z = isnan(z) || isinf(z) ? 0.0 : z; value_t dist = std::numeric_limits::max(); - if (i < k || z <= heap.warpKTop) { + + if (z <= heap.warpKTop) { const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { From 6ec636331a8b76f3e79ac102d27e0069cfcbce04 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 16:10:31 -0500 Subject: [PATCH 05/12] Removing print --- cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 81f71e37dc..7c5859e043 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -96,7 +96,6 @@ __global__ void perform_post_filter_registers(const value_t* X, // zero out bits for closest k landmarks for (value_int j = threadIdx.x; j < k; j += tpb) { - int la = (int)R_knn_inds[blockIdx.x * k + j]; _zero_bit(shared_mem, (std::uint32_t)R_knn_inds[blockIdx.x * k + j]); } From 479524446cf99114bca86b0ebeac03906ef9105f Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 18:33:06 -0500 Subject: [PATCH 06/12] Adding rbc in 3d --- cpp/CMakeLists.txt | 5 +- cpp/include/raft/spatial/knn/ball_cover.hpp | 6 +- .../raft/spatial/knn/detail/ball_cover.cuh | 93 +++++++++++++------ .../knn/detail/ball_cover/registers.cuh | 24 ++--- .../detail/ball_cover_lowdim.hpp | 33 ++++++- ...im.cu => ball_cover_lowdim_pass_one_2d.cu} | 14 +-- .../detail/ball_cover_lowdim_pass_one_3d.cu | 56 +++++++++++ .../detail/ball_cover_lowdim_pass_two_2d.cu | 42 +++++++++ .../detail/ball_cover_lowdim_pass_two_3d.cu | 43 +++++++++ 9 files changed, 256 insertions(+), 60 deletions(-) rename cpp/src/nn/specializations/detail/{ball_cover_lowdim.cu => ball_cover_lowdim_pass_one_2d.cu} (74%) create mode 100644 cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu create mode 100644 cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu create mode 100644 cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ea0ef2c2f1..6abbf90708 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -231,7 +231,10 @@ set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn) if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_NN_LIBRARY) add_library(raft_nn_lib SHARED src/nn/specializations/ball_cover.cu - src/nn/specializations/detail/ball_cover_lowdim.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu src/nn/specializations/fused_l2_knn_long_float_true.cu src/nn/specializations/fused_l2_knn_long_float_false.cu src/nn/specializations/fused_l2_knn_int_float_true.cu diff --git a/cpp/include/raft/spatial/knn/ball_cover.hpp b/cpp/include/raft/spatial/knn/ball_cover.hpp index 5b93439218..78fa49b07c 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover.hpp @@ -32,7 +32,7 @@ template & index) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_build_index(handle, index, detail::HaversineFunc()); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -82,7 +82,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_all_knn_query(handle, index, @@ -149,7 +149,7 @@ void rbc_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_knn_query(handle, index, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 2b245d06cb..32303be745 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -247,34 +247,67 @@ void perform_rbc_query(const raft::handle_t& handle, dists + (k * n_query_pts), std::numeric_limits::max()); - // Compute nearest k for each neighborhood in each closest R - rbc_low_dim_pass_one(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - dists_counter); - - if (perform_post_filtering) { - rbc_low_dim_pass_two(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - post_dists_counter); + if(index.n == 2) { + // Compute nearest k for each neighborhood in each closest R + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); + + if (perform_post_filtering) { + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); + } + + } else if(index.n == 3) { + // Compute nearest k for each neighborhood in each closest R + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); + + if (perform_post_filtering) { + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); + } } + } /** @@ -296,7 +329,7 @@ void rbc_build_index(const raft::handle_t& handle, BallCoverIndex& index, distance_func dfunc) { - ASSERT(index.n == 2, "only 2d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(!index.is_index_trained(), "index cannot be previously trained"); rmm::device_uvector R_knn_inds(index.m, handle.get_stream()); @@ -356,7 +389,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "only 2d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k"); ASSERT(!index.is_index_trained(), "index cannot be previously trained"); @@ -422,7 +455,7 @@ void rbc_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "only 2d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k"); ASSERT(index.is_index_trained(), "index must be previously trained"); diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 7c5859e043..f21592de21 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -466,6 +466,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, template void rbc_low_dim_pass_one(const raft::handle_t& handle, BallCoverIndex& index, @@ -481,7 +482,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, value_int* dists_counter) { if (k <= 32) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -518,7 +519,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, dfunc, weight); else if (k <= 128) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -537,7 +538,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, weight); else if (k <= 256) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -556,7 +557,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, weight); else if (k <= 512) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -575,7 +576,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, weight); else if (k <= 1024) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -597,6 +598,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, template void rbc_low_dim_pass_two(const raft::handle_t& handle, BallCoverIndex& index, @@ -640,7 +642,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 32, 2, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -665,7 +667,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 64, 3, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -690,7 +692,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 128, 3, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -715,7 +717,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 256, 4, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -740,7 +742,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 512, 8, 64, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -765,7 +767,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 1024, 8, 64, - 2> + dims> <<>>(index.get_X(), query, index.n, diff --git a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp index d0e4813332..5d9fe218e2 100644 --- a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp +++ b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp @@ -23,7 +23,7 @@ namespace spatial { namespace knn { namespace detail { -extern template void rbc_low_dim_pass_one( +extern template void rbc_low_dim_pass_one( const raft::handle_t& handle, BallCoverIndex& index, const float* query, @@ -37,7 +37,7 @@ extern template void rbc_low_dim_pass_one( float weight, std::uint32_t* dists_counter); -extern template void rbc_low_dim_pass_two( +extern template void rbc_low_dim_pass_two( const raft::handle_t& handle, BallCoverIndex& index, const float* query, @@ -50,6 +50,35 @@ extern template void rbc_low_dim_pass_two( float* dists, float weight, std::uint32_t* post_dists_counter); + +extern template void rbc_low_dim_pass_one( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* dists_counter); + +extern template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu similarity index 74% rename from cpp/src/nn/specializations/detail/ball_cover_lowdim.cu rename to cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu index dea7fe8d41..50a6b0c270 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu @@ -15,6 +15,7 @@ */ #include +#include #include namespace raft { @@ -36,19 +37,6 @@ template void rbc_low_dim_pass_one( float weight, std::uint32_t* dists_counter); -template void rbc_low_dim_pass_two( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu new file mode 100644 index 0000000000..5f939fbbd0 --- /dev/null +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { + namespace spatial { + namespace knn { + namespace detail { + + template void rbc_low_dim_pass_one( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* dists_counter); + + template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); + }; // namespace detail + }; // namespace knn + }; // namespace spatial +}; // namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu new file mode 100644 index 0000000000..8ef0ecf0b4 --- /dev/null +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { + namespace spatial { + namespace knn { + namespace detail { + + template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); + }; // namespace detail + }; // namespace knn + }; // namespace spatial +}; // namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu new file mode 100644 index 0000000000..b4856d4f28 --- /dev/null +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { + namespace spatial { + namespace knn { + namespace detail { + + + template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); + }; // namespace detail + }; // namespace knn + }; // namespace spatial +}; // namespace raft \ No newline at end of file From abb97b1d5274603957b113d69abe30696ee21ca5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 18:35:50 -0500 Subject: [PATCH 07/12] Style --- cpp/include/raft/spatial/knn/ball_cover.hpp | 6 +- .../raft/spatial/knn/detail/ball_cover.cuh | 75 +++++++++---------- .../knn/detail/ball_cover/registers.cuh | 4 +- .../detail/ball_cover_lowdim.hpp | 48 ++++++------ .../detail/ball_cover_lowdim_pass_one_2d.cu | 2 +- .../detail/ball_cover_lowdim_pass_one_3d.cu | 66 ++++++++-------- .../detail/ball_cover_lowdim_pass_two_2d.cu | 40 +++++----- .../detail/ball_cover_lowdim_pass_two_3d.cu | 41 +++++----- 8 files changed, 140 insertions(+), 142 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.hpp b/cpp/include/raft/spatial/knn/ball_cover.hpp index 78fa49b07c..aae6aa90e9 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover.hpp @@ -32,7 +32,7 @@ template & index) { - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_build_index(handle, index, detail::HaversineFunc()); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -82,7 +82,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_all_knn_query(handle, index, @@ -149,7 +149,7 @@ void rbc_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_knn_query(handle, index, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 32303be745..d076ee44ba 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -247,9 +247,9 @@ void perform_rbc_query(const raft::handle_t& handle, dists + (k * n_query_pts), std::numeric_limits::max()); - if(index.n == 2) { - // Compute nearest k for each neighborhood in each closest R - rbc_low_dim_pass_one(handle, + if (index.n == 2) { + // Compute nearest k for each neighborhood in each closest R + rbc_low_dim_pass_one(handle, index, query, n_query_pts, @@ -262,24 +262,8 @@ void perform_rbc_query(const raft::handle_t& handle, weight, dists_counter); - if (perform_post_filtering) { - rbc_low_dim_pass_two(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - post_dists_counter); - } - - } else if(index.n == 3) { - // Compute nearest k for each neighborhood in each closest R - rbc_low_dim_pass_one(handle, + if (perform_post_filtering) { + rbc_low_dim_pass_two(handle, index, query, n_query_pts, @@ -290,24 +274,39 @@ void perform_rbc_query(const raft::handle_t& handle, inds, dists, weight, - dists_counter); - - if (perform_post_filtering) { - rbc_low_dim_pass_two(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - post_dists_counter); - } - } + post_dists_counter); + } + + } else if (index.n == 3) { + // Compute nearest k for each neighborhood in each closest R + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); + if (perform_post_filtering) { + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); + } + } } /** diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index f21592de21..a1add701c9 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -466,7 +466,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, template void rbc_low_dim_pass_one(const raft::handle_t& handle, BallCoverIndex& index, @@ -598,7 +598,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, template void rbc_low_dim_pass_two(const raft::handle_t& handle, BallCoverIndex& index, diff --git a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp index 5d9fe218e2..396b344044 100644 --- a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp +++ b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp @@ -52,32 +52,32 @@ extern template void rbc_low_dim_pass_two std::uint32_t* post_dists_counter); extern template void rbc_low_dim_pass_one( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* dists_counter); extern template void rbc_low_dim_pass_two( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); }; // namespace detail }; // namespace knn diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu index 50a6b0c270..062befcc99 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu @@ -15,8 +15,8 @@ */ #include -#include #include +#include namespace raft { namespace spatial { diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu index 5f939fbbd0..4edd768f0a 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -15,42 +15,42 @@ */ #include -#include #include +#include namespace raft { - namespace spatial { - namespace knn { - namespace detail { +namespace spatial { +namespace knn { +namespace detail { - template void rbc_low_dim_pass_one( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); +template void rbc_low_dim_pass_one( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* dists_counter); - template void rbc_low_dim_pass_two( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); - }; // namespace detail - }; // namespace knn - }; // namespace spatial +template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); +}; // namespace detail +}; // namespace knn +}; // namespace spatial }; // namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu index 8ef0ecf0b4..3ad0404d1e 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu @@ -15,28 +15,28 @@ */ #include -#include #include +#include namespace raft { - namespace spatial { - namespace knn { - namespace detail { +namespace spatial { +namespace knn { +namespace detail { - template void rbc_low_dim_pass_two( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); - }; // namespace detail - }; // namespace knn - }; // namespace spatial +template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); +}; // namespace detail +}; // namespace knn +}; // namespace spatial }; // namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu index b4856d4f28..bb9a358cf4 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu @@ -15,29 +15,28 @@ */ #include -#include #include +#include namespace raft { - namespace spatial { - namespace knn { - namespace detail { - +namespace spatial { +namespace knn { +namespace detail { - template void rbc_low_dim_pass_two( - const raft::handle_t& handle, - BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); - }; // namespace detail - }; // namespace knn - }; // namespace spatial +template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); +}; // namespace detail +}; // namespace knn +}; // namespace spatial }; // namespace raft \ No newline at end of file From ab58daa89f70899b115d054169273b74f40040a4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Feb 2022 16:59:04 -0500 Subject: [PATCH 08/12] Updating --- cpp/test/spatial/ball_cover.cu | 40 ++++++++++++++++------------------ 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index d8b5f8dbda..0b76e58736 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -18,8 +18,8 @@ #include "spatial_data.h" #include #include -#include #include +#include #include #if defined RAFT_NN_COMPILED #include @@ -140,25 +140,23 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { params = ::testing::TestWithParam::GetParam(); raft::handle_t handle; - uint32_t k = params.k; + uint32_t k = params.k; uint32_t n_centers = 25; - float weight = params.weight; - auto metric = params.metric; + float weight = params.weight; + auto metric = params.metric; rmm::device_uvector X(params.n_rows * params.n_cols, handle.get_stream()); rmm::device_uvector Y(params.n_rows, handle.get_stream()); - raft::random::make_blobs(X.data(), Y.data(), params.n_rows, params.n_cols, n_centers, handle.get_stream()); + raft::random::make_blobs( + X.data(), Y.data(), params.n_rows, params.n_cols, n_centers, handle.get_stream()); rmm::device_uvector d_ref_I(params.n_query * k, handle.get_stream()); rmm::device_uvector d_ref_D(params.n_query * k, handle.get_stream()); if (metric == raft::distance::DistanceType::Haversine) { - thrust::transform(handle.get_thrust_policy(), - X.data(), - X.data() + X.size(), - X.data(), - ToRadians()); + thrust::transform( + handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); } compute_bfknn(handle, @@ -177,7 +175,8 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_query * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_query * k, handle.get_stream()); - BallCoverIndex index(handle, X.data(), params.n_rows, params.n_cols, metric); + BallCoverIndex index( + handle, X.data(), params.n_rows, params.n_cols, metric); raft::spatial::knn::rbc_build_index(handle, index); raft::spatial::knn::rbc_knn_query( @@ -223,25 +222,23 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { params = ::testing::TestWithParam::GetParam(); raft::handle_t handle; - uint32_t k = params.k; + uint32_t k = params.k; uint32_t n_centers = 25; - float weight = params.weight; - auto metric = params.metric; + float weight = params.weight; + auto metric = params.metric; rmm::device_uvector X(params.n_rows * params.n_cols, handle.get_stream()); rmm::device_uvector Y(params.n_rows, handle.get_stream()); - raft::random::make_blobs(X.data(), Y.data(), params.n_rows, params.n_cols, n_centers, handle.get_stream()); + raft::random::make_blobs( + X.data(), Y.data(), params.n_rows, params.n_cols, n_centers, handle.get_stream()); rmm::device_uvector d_ref_I(params.n_rows * k, handle.get_stream()); rmm::device_uvector d_ref_D(params.n_rows * k, handle.get_stream()); if (metric == raft::distance::DistanceType::Haversine) { - thrust::transform(handle.get_thrust_policy(), - X.data(), - X.data() + X.size(), - X.data(), - ToRadians()); + thrust::transform( + handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); } std::vector* translations = nullptr; @@ -269,7 +266,8 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_rows * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_rows * k, handle.get_stream()); - BallCoverIndex index(handle, X.data(), params.n_rows, params.n_cols, metric); + BallCoverIndex index( + handle, X.data(), params.n_rows, params.n_cols, metric); raft::spatial::knn::rbc_all_knn_query( handle, index, k, d_pred_I.data(), d_pred_D.data(), true, weight); From a0f1da5f9d1feb919f45b317c79bb02cdf9c05a3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Feb 2022 17:28:08 -0500 Subject: [PATCH 09/12] Using make_blobs --- cpp/test/spatial/ball_cover.cu | 36 ++++++++++++++-------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 0b76e58736..ae76587dd5 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -92,7 +92,8 @@ template void compute_bfknn(const raft::handle_t& handle, const value_t* X1, const value_t* X2, - uint32_t n, + uint32_t n_rows, + uint32_t n_query_rows, uint32_t d, uint32_t k, const raft::distance::DistanceType metric, @@ -100,7 +101,7 @@ void compute_bfknn(const raft::handle_t& handle, int64_t* inds) { std::vector input_vec = {const_cast(X1)}; - std::vector sizes_vec = {n}; + std::vector sizes_vec = {n_rows}; std::vector* translations = nullptr; @@ -109,7 +110,7 @@ void compute_bfknn(const raft::handle_t& handle, sizes_vec, d, const_cast(X2), - n, + n_query_rows, inds, dists, k, @@ -162,6 +163,7 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { compute_bfknn(handle, X.data(), X.data(), + params.n_rows, params.n_query, params.n_cols, k, @@ -241,24 +243,16 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); } - std::vector* translations = nullptr; - - std::vector input_vec = {X.data()}; - std::vector sizes_vec = {params.n_rows}; - - raft::spatial::knn::detail::brute_force_knn_impl(handle, - input_vec, - sizes_vec, - params.n_cols, - X.data(), - params.n_rows, - d_ref_I.data(), - d_ref_D.data(), - k, - true, - true, - translations, - metric); + compute_bfknn(handle, + X.data(), + X.data(), + params.n_rows, + params.n_rows, + params.n_cols, + k, + metric, + d_ref_D.data(), + d_ref_I.data()); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); From b6837cfa8ba7061d27e2b340a21a68dd4fd42c6c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Feb 2022 17:32:50 -0500 Subject: [PATCH 10/12] Fixing bug in 3d --- cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index a1add701c9..17d3391190 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -61,6 +61,7 @@ namespace detail { template __global__ void perform_post_filter_registers(const value_t* X, @@ -87,7 +88,7 @@ __global__ void perform_post_filter_registers(const value_t* X, __syncthreads(); // TODO: Would it be faster to use L1 for this? - value_t local_x_ptr[2]; + value_t local_x_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_x_ptr[j] = X[n_cols * blockIdx.x + j]; } @@ -618,7 +619,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, rmm::device_uvector bitset(bitset_size * index.m, handle.get_stream()); thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); - perform_post_filter_registers + perform_post_filter_registers <<>>( index.get_X(), index.n, From a0dc97b44251eb2625354575f5f3c10d9180a9e4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Feb 2022 18:55:21 -0500 Subject: [PATCH 11/12] Updating copyrights --- cpp/include/raft/spatial/knn/ball_cover.hpp | 2 +- cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh | 2 +- .../spatial/knn/specializations/detail/ball_cover_lowdim.hpp | 2 +- .../nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu | 2 +- .../nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu | 2 +- .../nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu | 2 +- .../nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu | 2 +- cpp/test/spatial/ball_cover.cu | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.hpp b/cpp/include/raft/spatial/knn/ball_cover.hpp index aae6aa90e9..d44e87710b 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 17d3391190..ae9e607626 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp index 396b344044..afee3bd7a3 100644 --- a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp +++ b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu index 062befcc99..8950ff8d5c 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu index 4edd768f0a..7b8b6ce9a2 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu index 3ad0404d1e..29e8eec8c8 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu index bb9a358cf4..d6d4b356c8 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index ae76587dd5..0cdc0d8765 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From c8975bfab793e1e9e48dca62a184aa9b6553c281 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Feb 2022 21:35:55 -0500 Subject: [PATCH 12/12] Enabling optional static or shared linking for built libraries (defaults to shared) --- cpp/CMakeLists.txt | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8c4336f597..484285bf84 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -48,6 +48,7 @@ option(DETECT_CONDA_ENV "Enable detection of conda environment for dependencies" option(DISABLE_DEPRECATION_WARNINGS "Disable depreaction warnings " ON) option(DISABLE_OPENMP "Disable OpenMP" OFF) option(NVTX "Enable nvtx markers" OFF) +option(RAFT_STATIC_LINK_LIBRARIES "Statically link compiled libraft libraries") option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ON) option(RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" OFF) @@ -156,6 +157,11 @@ SECTIONS } ]=]) endif() + +set(RAFT_LIB_TYPE SHARED) +if(${RAFT_STATIC_LINK_LIBRARIES}) + set(RAFT_LIB_TYPE STATIC) +endif() ############################################################################## # - raft_distance ------------------------------------------------------------ add_library(raft_distance INTERFACE) @@ -167,7 +173,7 @@ endif() set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance) if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_DIST_LIBRARY) - add_library(raft_distance_lib SHARED + add_library(raft_distance_lib ${RAFT_LIB_TYPE} src/distance/specializations/detail src/distance/specializations/detail/canberra.cu src/distance/specializations/detail/chebyshev.cu @@ -231,7 +237,7 @@ endif() set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn) if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_NN_LIBRARY) - add_library(raft_nn_lib SHARED + add_library(raft_nn_lib ${RAFT_LIB_TYPE} src/nn/specializations/ball_cover.cu src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu