diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9c54d15adc..484285bf84 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -48,6 +48,7 @@ option(DETECT_CONDA_ENV "Enable detection of conda environment for dependencies" option(DISABLE_DEPRECATION_WARNINGS "Disable depreaction warnings " ON) option(DISABLE_OPENMP "Disable OpenMP" OFF) option(NVTX "Enable nvtx markers" OFF) +option(RAFT_STATIC_LINK_LIBRARIES "Statically link compiled libraft libraries") option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ON) option(RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" OFF) @@ -156,6 +157,11 @@ SECTIONS } ]=]) endif() + +set(RAFT_LIB_TYPE SHARED) +if(${RAFT_STATIC_LINK_LIBRARIES}) + set(RAFT_LIB_TYPE STATIC) +endif() ############################################################################## # - raft_distance ------------------------------------------------------------ add_library(raft_distance INTERFACE) @@ -167,7 +173,7 @@ endif() set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance) if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_DIST_LIBRARY) - add_library(raft_distance_lib SHARED + add_library(raft_distance_lib ${RAFT_LIB_TYPE} src/distance/specializations/detail src/distance/specializations/detail/canberra.cu src/distance/specializations/detail/chebyshev.cu @@ -231,9 +237,12 @@ endif() set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn) if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_NN_LIBRARY) - add_library(raft_nn_lib SHARED + add_library(raft_nn_lib ${RAFT_LIB_TYPE} src/nn/specializations/ball_cover.cu - src/nn/specializations/detail/ball_cover_lowdim.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu + src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu src/nn/specializations/fused_l2_knn_long_float_true.cu src/nn/specializations/fused_l2_knn_long_float_false.cu src/nn/specializations/fused_l2_knn_int_float_true.cu diff --git a/cpp/include/raft/spatial/knn/ball_cover.hpp b/cpp/include/raft/spatial/knn/ball_cover.hpp index 5b93439218..d44e87710b 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ template & index) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_build_index(handle, index, detail::HaversineFunc()); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -82,7 +82,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_all_knn_query(handle, index, @@ -149,7 +149,7 @@ void rbc_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_knn_query(handle, index, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 4911582ed9..d430a98ea0 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -248,33 +248,65 @@ void perform_rbc_query(const raft::handle_t& handle, dists + (k * n_query_pts), std::numeric_limits::max()); - // Compute nearest k for each neighborhood in each closest R - rbc_low_dim_pass_one(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - dists_counter); - - if (perform_post_filtering) { - rbc_low_dim_pass_two(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - post_dists_counter); + if (index.n == 2) { + // Compute nearest k for each neighborhood in each closest R + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); + + if (perform_post_filtering) { + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); + } + + } else if (index.n == 3) { + // Compute nearest k for each neighborhood in each closest R + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); + + if (perform_post_filtering) { + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); + } } } @@ -297,7 +329,7 @@ void rbc_build_index(const raft::handle_t& handle, BallCoverIndex& index, distance_func dfunc) { - ASSERT(index.n == 2, "only 2d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(!index.is_index_trained(), "index cannot be previously trained"); rmm::device_uvector R_knn_inds(index.m, handle.get_stream()); @@ -357,7 +389,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "only 2d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k"); ASSERT(!index.is_index_trained(), "index cannot be previously trained"); @@ -423,7 +455,7 @@ void rbc_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "only 2d vectors are supported in current implementation"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k"); ASSERT(index.is_index_trained(), "index must be previously trained"); diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 7c5859e043..ae9e607626 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ namespace detail { template __global__ void perform_post_filter_registers(const value_t* X, @@ -87,7 +88,7 @@ __global__ void perform_post_filter_registers(const value_t* X, __syncthreads(); // TODO: Would it be faster to use L1 for this? - value_t local_x_ptr[2]; + value_t local_x_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_x_ptr[j] = X[n_cols * blockIdx.x + j]; } @@ -466,6 +467,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, template void rbc_low_dim_pass_one(const raft::handle_t& handle, BallCoverIndex& index, @@ -481,7 +483,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, value_int* dists_counter) { if (k <= 32) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -518,7 +520,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, dfunc, weight); else if (k <= 128) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -537,7 +539,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, weight); else if (k <= 256) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -556,7 +558,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, weight); else if (k <= 512) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -575,7 +577,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, weight); else if (k <= 1024) - block_rbc_kernel_registers + block_rbc_kernel_registers <<>>(index.get_X(), query, index.n, @@ -597,6 +599,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, template void rbc_low_dim_pass_two(const raft::handle_t& handle, BallCoverIndex& index, @@ -616,7 +619,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, rmm::device_uvector bitset(bitset_size * index.m, handle.get_stream()); thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); - perform_post_filter_registers + perform_post_filter_registers <<>>( index.get_X(), index.n, @@ -640,7 +643,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 32, 2, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -665,7 +668,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 64, 3, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -690,7 +693,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 128, 3, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -715,7 +718,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 256, 4, 128, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -740,7 +743,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 512, 8, 64, - 2> + dims> <<>>(index.get_X(), query, index.n, @@ -765,7 +768,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 1024, 8, 64, - 2> + dims> <<>>(index.get_X(), query, index.n, diff --git a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp index d0e4813332..afee3bd7a3 100644 --- a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp +++ b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace spatial { namespace knn { namespace detail { -extern template void rbc_low_dim_pass_one( +extern template void rbc_low_dim_pass_one( const raft::handle_t& handle, BallCoverIndex& index, const float* query, @@ -37,7 +37,7 @@ extern template void rbc_low_dim_pass_one( float weight, std::uint32_t* dists_counter); -extern template void rbc_low_dim_pass_two( +extern template void rbc_low_dim_pass_two( const raft::handle_t& handle, BallCoverIndex& index, const float* query, @@ -50,6 +50,35 @@ extern template void rbc_low_dim_pass_two( float* dists, float weight, std::uint32_t* post_dists_counter); + +extern template void rbc_low_dim_pass_one( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* dists_counter); + +extern template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu new file mode 100644 index 0000000000..8950ff8d5c --- /dev/null +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template void rbc_low_dim_pass_one( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* dists_counter); + +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu similarity index 92% rename from cpp/src/nn/specializations/detail/ball_cover_lowdim.cu rename to cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu index dea7fe8d41..7b8b6ce9a2 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,14 @@ #include #include +#include namespace raft { namespace spatial { namespace knn { namespace detail { -template void rbc_low_dim_pass_one( +template void rbc_low_dim_pass_one( const raft::handle_t& handle, BallCoverIndex& index, const float* query, @@ -36,7 +37,7 @@ template void rbc_low_dim_pass_one( float weight, std::uint32_t* dists_counter); -template void rbc_low_dim_pass_two( +template void rbc_low_dim_pass_two( const raft::handle_t& handle, BallCoverIndex& index, const float* query, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu new file mode 100644 index 0000000000..29e8eec8c8 --- /dev/null +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu new file mode 100644 index 0000000000..d6d4b356c8 --- /dev/null +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template void rbc_low_dim_pass_two( + const raft::handle_t& handle, + BallCoverIndex& index, + const float* query, + const std::uint32_t n_query_rows, + std::uint32_t k, + const std::int64_t* R_knn_inds, + const float* R_knn_dists, + DistFunc& dfunc, + std::int64_t* inds, + float* dists, + float weight, + std::uint32_t* post_dists_counter); +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft \ No newline at end of file diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 66cd11be1f..0cdc0d8765 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include "spatial_data.h" #include #include +#include #include #include #if defined RAFT_NN_COMPILED @@ -57,13 +58,6 @@ __global__ void count_discrepancies_kernel(value_idx* actual_idx, value_t d = actual[row * n + i] - expected[row * n + i]; bool matches = (fabsf(d) <= thres) || (actual_idx[row * n + i] == expected_idx[row * n + i] && actual_idx[row * n + i] == row); - // if (!matches) { - // printf("row=%d, actual_idx=%ld, actual=%f, expected_id=%ld, - // expected=%f\n", - // row, actual_idx[row*n+i], actual[row*n+i], expected_idx[row*n+i], - // expected[row*n+i]); - // } - n_diffs += !matches; out[row] = n_diffs; } @@ -98,7 +92,8 @@ template void compute_bfknn(const raft::handle_t& handle, const value_t* X1, const value_t* X2, - uint32_t n, + uint32_t n_rows, + uint32_t n_query_rows, uint32_t d, uint32_t k, const raft::distance::DistanceType metric, @@ -106,7 +101,7 @@ void compute_bfknn(const raft::handle_t& handle, int64_t* inds) { std::vector input_vec = {const_cast(X1)}; - std::vector sizes_vec = {n}; + std::vector sizes_vec = {n_rows}; std::vector* translations = nullptr; @@ -115,7 +110,7 @@ void compute_bfknn(const raft::handle_t& handle, sizes_vec, d, const_cast(X2), - n, + n_query_rows, inds, dists, k, @@ -131,7 +126,10 @@ struct ToRadians { struct BallCoverInputs { uint32_t k; + uint32_t n_rows; + uint32_t n_cols; float weight; + uint32_t n_query; raft::distance::DistanceType metric; }; @@ -143,34 +141,31 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { params = ::testing::TestWithParam::GetParam(); raft::handle_t handle; - uint32_t k = params.k; - float weight = params.weight; - auto metric = params.metric; - - std::vector h_train_inputs = spatial_data; + uint32_t k = params.k; + uint32_t n_centers = 25; + float weight = params.weight; + auto metric = params.metric; - uint32_t n = h_train_inputs.size() / d; + rmm::device_uvector X(params.n_rows * params.n_cols, handle.get_stream()); + rmm::device_uvector Y(params.n_rows, handle.get_stream()); - rmm::device_uvector d_ref_I(n * k, handle.get_stream()); - rmm::device_uvector d_ref_D(n * k, handle.get_stream()); + raft::random::make_blobs( + X.data(), Y.data(), params.n_rows, params.n_cols, n_centers, handle.get_stream()); - // Allocate input - rmm::device_uvector d_train_inputs(n * d, handle.get_stream()); - raft::update_device(d_train_inputs.data(), h_train_inputs.data(), n * d, handle.get_stream()); + rmm::device_uvector d_ref_I(params.n_query * k, handle.get_stream()); + rmm::device_uvector d_ref_D(params.n_query * k, handle.get_stream()); if (metric == raft::distance::DistanceType::Haversine) { - thrust::transform(handle.get_thrust_policy(), - d_train_inputs.data(), - d_train_inputs.data() + d_train_inputs.size(), - d_train_inputs.data(), - ToRadians()); + thrust::transform( + handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); } compute_bfknn(handle, - d_train_inputs.data(), - d_train_inputs.data(), - n, - d, + X.data(), + X.data(), + params.n_rows, + params.n_query, + params.n_cols, k, metric, d_ref_D.data(), @@ -179,21 +174,22 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); // Allocate predicted arrays - rmm::device_uvector d_pred_I(n * k, handle.get_stream()); - rmm::device_uvector d_pred_D(n * k, handle.get_stream()); + rmm::device_uvector d_pred_I(params.n_query * k, handle.get_stream()); + rmm::device_uvector d_pred_D(params.n_query * k, handle.get_stream()); - BallCoverIndex index(handle, d_train_inputs.data(), n, d, metric); + BallCoverIndex index( + handle, X.data(), params.n_rows, params.n_cols, metric); raft::spatial::knn::rbc_build_index(handle, index); raft::spatial::knn::rbc_knn_query( - handle, index, k, d_train_inputs.data(), n, d_pred_I.data(), d_pred_D.data(), true, weight); + handle, index, k, X.data(), params.n_query, d_pred_I.data(), d_pred_D.data(), true, weight); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); // What we really want are for the distances to match exactly. The // indices may or may not match exactly, depending upon the ordering which // can be nondeterministic. - rmm::device_uvector discrepancies(n, handle.get_stream()); + rmm::device_uvector discrepancies(params.n_query, handle.get_stream()); thrust::fill(handle.get_thrust_policy(), discrepancies.data(), discrepancies.data() + discrepancies.size(), @@ -203,7 +199,7 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { d_pred_I.data(), d_ref_D.data(), d_pred_D.data(), - n, + params.n_query, k, discrepancies.data(), handle.get_stream()); @@ -228,55 +224,44 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { params = ::testing::TestWithParam::GetParam(); raft::handle_t handle; - uint32_t k = params.k; - float weight = params.weight; - auto metric = params.metric; - - std::vector h_train_inputs = spatial_data; + uint32_t k = params.k; + uint32_t n_centers = 25; + float weight = params.weight; + auto metric = params.metric; - uint32_t n = h_train_inputs.size() / d; + rmm::device_uvector X(params.n_rows * params.n_cols, handle.get_stream()); + rmm::device_uvector Y(params.n_rows, handle.get_stream()); - rmm::device_uvector d_ref_I(n * k, handle.get_stream()); - rmm::device_uvector d_ref_D(n * k, handle.get_stream()); + raft::random::make_blobs( + X.data(), Y.data(), params.n_rows, params.n_cols, n_centers, handle.get_stream()); - // Allocate input - rmm::device_uvector d_train_inputs(n * d, handle.get_stream()); - raft::update_device(d_train_inputs.data(), h_train_inputs.data(), n * d, handle.get_stream()); + rmm::device_uvector d_ref_I(params.n_rows * k, handle.get_stream()); + rmm::device_uvector d_ref_D(params.n_rows * k, handle.get_stream()); if (metric == raft::distance::DistanceType::Haversine) { - thrust::transform(handle.get_thrust_policy(), - d_train_inputs.data(), - d_train_inputs.data() + d_train_inputs.size(), - d_train_inputs.data(), - ToRadians()); + thrust::transform( + handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); } - std::vector* translations = nullptr; - - std::vector input_vec = {d_train_inputs.data()}; - std::vector sizes_vec = {n}; - - raft::spatial::knn::detail::brute_force_knn_impl(handle, - input_vec, - sizes_vec, - d, - d_train_inputs.data(), - n, - d_ref_I.data(), - d_ref_D.data(), - k, - true, - true, - translations, - metric); + compute_bfknn(handle, + X.data(), + X.data(), + params.n_rows, + params.n_rows, + params.n_cols, + k, + metric, + d_ref_D.data(), + d_ref_I.data()); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); // Allocate predicted arrays - rmm::device_uvector d_pred_I(n * k, handle.get_stream()); - rmm::device_uvector d_pred_D(n * k, handle.get_stream()); + rmm::device_uvector d_pred_I(params.n_rows * k, handle.get_stream()); + rmm::device_uvector d_pred_D(params.n_rows * k, handle.get_stream()); - BallCoverIndex index(handle, d_train_inputs.data(), n, d, metric); + BallCoverIndex index( + handle, X.data(), params.n_rows, params.n_cols, metric); raft::spatial::knn::rbc_all_knn_query( handle, index, k, d_pred_I.data(), d_pred_D.data(), true, weight); @@ -286,7 +271,7 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { // indices may or may not match exactly, depending upon the ordering which // can be nondeterministic. - rmm::device_uvector discrepancies(n, handle.get_stream()); + rmm::device_uvector discrepancies(params.n_rows, handle.get_stream()); thrust::fill(handle.get_thrust_policy(), discrepancies.data(), discrepancies.data() + discrepancies.size(), @@ -296,7 +281,7 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { d_pred_I.data(), d_ref_D.data(), d_pred_D.data(), - n, + params.n_rows, k, discrepancies.data(), handle.get_stream()); @@ -308,7 +293,6 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { void TearDown() override {} protected: - uint32_t d = 2; BallCoverInputs params; }; @@ -316,12 +300,15 @@ typedef BallCoverAllKNNTest BallCoverAllKNNTestF; typedef BallCoverKNNQueryTest BallCoverKNNQueryTestF; const std::vector ballcover_inputs = { - {2, 1.0, raft::distance::DistanceType::Haversine}, - {4, 1.0, raft::distance::DistanceType::Haversine}, - {7, 1.0, raft::distance::DistanceType::Haversine}, - {2, 1.0, raft::distance::DistanceType::L2SqrtUnexpanded}, - {4, 1.0, raft::distance::DistanceType::L2SqrtUnexpanded}, - {7, 1.0, raft::distance::DistanceType::L2SqrtUnexpanded}, + {2, 10000, 2, 1.0, 5000, raft::distance::DistanceType::Haversine}, + {11, 10000, 2, 1.0, 5000, raft::distance::DistanceType::Haversine}, + {25, 10000, 2, 1.0, 5000, raft::distance::DistanceType::Haversine}, + {2, 10000, 2, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, + {11, 10000, 2, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, + {25, 10000, 2, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, + {2, 10000, 3, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, + {11, 10000, 3, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, + {25, 10000, 3, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, }; INSTANTIATE_TEST_CASE_P(BallCoverAllKNNTest,