Skip to content

Commit

Permalink
Random ball cover in 3d (#510)
Browse files Browse the repository at this point in the history
Todo:
- [x] new gtests w/ `make_blobs`

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Vinay Deshpande (https://github.com/vinaydes)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #510
  • Loading branch information
cjnolet authored Feb 24, 2022
1 parent 0bbdd4d commit 28fd5ef
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 140 deletions.
15 changes: 12 additions & 3 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ option(DETECT_CONDA_ENV "Enable detection of conda environment for dependencies"
option(DISABLE_DEPRECATION_WARNINGS "Disable depreaction warnings " ON)
option(DISABLE_OPENMP "Disable OpenMP" OFF)
option(NVTX "Enable nvtx markers" OFF)
option(RAFT_STATIC_LINK_LIBRARIES "Statically link compiled libraft libraries")

option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ON)
option(RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" OFF)
Expand Down Expand Up @@ -156,6 +157,11 @@ SECTIONS
}
]=])
endif()

set(RAFT_LIB_TYPE SHARED)
if(${RAFT_STATIC_LINK_LIBRARIES})
set(RAFT_LIB_TYPE STATIC)
endif()
##############################################################################
# - raft_distance ------------------------------------------------------------
add_library(raft_distance INTERFACE)
Expand All @@ -167,7 +173,7 @@ endif()
set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)

if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_DIST_LIBRARY)
add_library(raft_distance_lib SHARED
add_library(raft_distance_lib ${RAFT_LIB_TYPE}
src/distance/specializations/detail
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
Expand Down Expand Up @@ -231,9 +237,12 @@ endif()
set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn)

if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_NN_LIBRARY)
add_library(raft_nn_lib SHARED
add_library(raft_nn_lib ${RAFT_LIB_TYPE}
src/nn/specializations/ball_cover.cu
src/nn/specializations/detail/ball_cover_lowdim.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
src/nn/specializations/fused_l2_knn_long_float_true.cu
src/nn/specializations/fused_l2_knn_long_float_false.cu
src/nn/specializations/fused_l2_knn_int_float_true.cu
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/spatial/knn/ball_cover.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,7 +32,7 @@ template <typename value_idx = std::int64_t, typename value_t, typename value_in
void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index)
{
ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_build_index(handle, index, detail::HaversineFunc<value_t, value_int>());
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
Expand Down Expand Up @@ -82,7 +82,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_all_knn_query(handle,
index,
Expand Down Expand Up @@ -149,7 +149,7 @@ void rbc_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_knn_query(handle,
index,
Expand Down
92 changes: 62 additions & 30 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -248,33 +248,65 @@ void perform_rbc_query(const raft::handle_t& handle,
dists + (k * n_query_pts),
std::numeric_limits<value_t>::max());

// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
dists_counter);

if (perform_post_filtering) {
rbc_low_dim_pass_two(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
post_dists_counter);
if (index.n == 2) {
// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one<value_idx, value_t, value_int, 2>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
dists_counter);

if (perform_post_filtering) {
rbc_low_dim_pass_two<value_idx, value_t, value_int, 2>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
post_dists_counter);
}

} else if (index.n == 3) {
// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one<value_idx, value_t, value_int, 3>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
dists_counter);

if (perform_post_filtering) {
rbc_low_dim_pass_two<value_idx, value_t, value_int, 3>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
post_dists_counter);
}
}
}

Expand All @@ -297,7 +329,7 @@ void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
distance_func dfunc)
{
ASSERT(index.n == 2, "only 2d vectors are supported in current implementation");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(!index.is_index_trained(), "index cannot be previously trained");

rmm::device_uvector<value_idx> R_knn_inds(index.m, handle.get_stream());
Expand Down Expand Up @@ -357,7 +389,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "only 2d vectors are supported in current implementation");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k");
ASSERT(!index.is_index_trained(), "index cannot be previously trained");

Expand Down Expand Up @@ -423,7 +455,7 @@ void rbc_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "only 2d vectors are supported in current implementation");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k");
ASSERT(index.is_index_trained(), "index must be previously trained");

Expand Down
31 changes: 17 additions & 14 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,6 +61,7 @@ namespace detail {
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int col_q = 2,
int tpb = 32,
typename distance_func>
__global__ void perform_post_filter_registers(const value_t* X,
Expand All @@ -87,7 +88,7 @@ __global__ void perform_post_filter_registers(const value_t* X,
__syncthreads();

// TODO: Would it be faster to use L1 for this?
value_t local_x_ptr[2];
value_t local_x_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_x_ptr[j] = X[n_cols * blockIdx.x + j];
}
Expand Down Expand Up @@ -466,6 +467,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int dims = 2,
typename dist_func>
void rbc_low_dim_pass_one(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
Expand All @@ -481,7 +483,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
value_int* dists_counter)
{
if (k <= 32)
block_rbc_kernel_registers<value_idx, value_t, 32, 2, 128, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 32, 2, 128, dims, value_int>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand Down Expand Up @@ -518,7 +520,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
dfunc,
weight);
else if (k <= 128)
block_rbc_kernel_registers<value_idx, value_t, 128, 3, 128, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 128, 3, 128, dims, value_int>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -537,7 +539,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
weight);

else if (k <= 256)
block_rbc_kernel_registers<value_idx, value_t, 256, 4, 128, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 256, 4, 128, dims, value_int>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -556,7 +558,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
weight);

else if (k <= 512)
block_rbc_kernel_registers<value_idx, value_t, 512, 8, 64, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 512, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -575,7 +577,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
weight);

else if (k <= 1024)
block_rbc_kernel_registers<value_idx, value_t, 1024, 8, 64, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 1024, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -597,6 +599,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int dims = 2,
typename dist_func>
void rbc_low_dim_pass_two(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
Expand All @@ -616,7 +619,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
rmm::device_uvector<std::uint32_t> bitset(bitset_size * index.m, handle.get_stream());
thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0);

perform_post_filter_registers<value_idx, value_t, value_int, 128>
perform_post_filter_registers<value_idx, value_t, value_int, dims, 128>
<<<n_query_rows, 128, bitset_size * sizeof(std::uint32_t), handle.get_stream()>>>(
index.get_X(),
index.n,
Expand All @@ -640,7 +643,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
32,
2,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -665,7 +668,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
64,
3,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -690,7 +693,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
128,
3,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -715,7 +718,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
256,
4,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -740,7 +743,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
512,
8,
64,
2>
dims>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -765,7 +768,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
1024,
8,
64,
2>
dims>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand Down
Loading

0 comments on commit 28fd5ef

Please sign in to comment.