Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random ball cover in 3d #510

Merged
merged 14 commits into from
Feb 24, 2022
Merged
15 changes: 12 additions & 3 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ option(DETECT_CONDA_ENV "Enable detection of conda environment for dependencies"
option(DISABLE_DEPRECATION_WARNINGS "Disable depreaction warnings " ON)
option(DISABLE_OPENMP "Disable OpenMP" OFF)
option(NVTX "Enable nvtx markers" OFF)
option(RAFT_STATIC_LINK_LIBRARIES "Statically link compiled libraft libraries")

option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ON)
option(RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" OFF)
Expand Down Expand Up @@ -156,6 +157,11 @@ SECTIONS
}
]=])
endif()

set(RAFT_LIB_TYPE SHARED)
if(${RAFT_STATIC_LINK_LIBRARIES})
set(RAFT_LIB_TYPE STATIC)
endif()
##############################################################################
# - raft_distance ------------------------------------------------------------
add_library(raft_distance INTERFACE)
Expand All @@ -167,7 +173,7 @@ endif()
set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)

if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_DIST_LIBRARY)
add_library(raft_distance_lib SHARED
add_library(raft_distance_lib ${RAFT_LIB_TYPE}
src/distance/specializations/detail
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
Expand Down Expand Up @@ -231,9 +237,12 @@ endif()
set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn)

if(RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_NN_LIBRARY)
add_library(raft_nn_lib SHARED
add_library(raft_nn_lib ${RAFT_LIB_TYPE}
src/nn/specializations/ball_cover.cu
src/nn/specializations/detail/ball_cover_lowdim.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
src/nn/specializations/fused_l2_knn_long_float_true.cu
src/nn/specializations/fused_l2_knn_long_float_false.cu
src/nn/specializations/fused_l2_knn_int_float_true.cu
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/spatial/knn/ball_cover.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,7 +32,7 @@ template <typename value_idx = std::int64_t, typename value_t, typename value_in
void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index)
{
ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_build_index(handle, index, detail::HaversineFunc<value_t, value_int>());
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
Expand Down Expand Up @@ -82,7 +82,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_all_knn_query(handle,
index,
Expand Down Expand Up @@ -149,7 +149,7 @@ void rbc_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_knn_query(handle,
index,
Expand Down
92 changes: 62 additions & 30 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -247,33 +247,65 @@ void perform_rbc_query(const raft::handle_t& handle,
dists + (k * n_query_pts),
std::numeric_limits<value_t>::max());

// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
dists_counter);

if (perform_post_filtering) {
rbc_low_dim_pass_two(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
post_dists_counter);
if (index.n == 2) {
// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one<value_idx, value_t, value_int, 2>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
dists_counter);

if (perform_post_filtering) {
rbc_low_dim_pass_two<value_idx, value_t, value_int, 2>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
post_dists_counter);
}

} else if (index.n == 3) {
// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one<value_idx, value_t, value_int, 3>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
dists_counter);

if (perform_post_filtering) {
rbc_low_dim_pass_two<value_idx, value_t, value_int, 3>(handle,
index,
query,
n_query_pts,
k,
R_knn_inds,
R_knn_dists,
dfunc,
inds,
dists,
weight,
post_dists_counter);
}
}
}

Expand All @@ -296,7 +328,7 @@ void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
distance_func dfunc)
{
ASSERT(index.n == 2, "only 2d vectors are supported in current implementation");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(!index.is_index_trained(), "index cannot be previously trained");

rmm::device_uvector<value_idx> R_knn_inds(index.m, handle.get_stream());
Expand Down Expand Up @@ -356,7 +388,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "only 2d vectors are supported in current implementation");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k");
ASSERT(!index.is_index_trained(), "index cannot be previously trained");

Expand Down Expand Up @@ -422,7 +454,7 @@ void rbc_knn_query(const raft::handle_t& handle,
bool perform_post_filtering = true,
float weight = 1.0)
{
ASSERT(index.n == 2, "only 2d vectors are supported in current implementation");
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(index.n_landmarks >= k, "number of landmark samples must be >= k");
ASSERT(index.is_index_trained(), "index must be previously trained");

Expand Down
31 changes: 17 additions & 14 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,6 +61,7 @@ namespace detail {
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int col_q = 2,
int tpb = 32,
typename distance_func>
__global__ void perform_post_filter_registers(const value_t* X,
Expand All @@ -87,7 +88,7 @@ __global__ void perform_post_filter_registers(const value_t* X,
__syncthreads();

// TODO: Would it be faster to use L1 for this?
value_t local_x_ptr[2];
value_t local_x_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_x_ptr[j] = X[n_cols * blockIdx.x + j];
}
Expand Down Expand Up @@ -466,6 +467,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int dims = 2,
typename dist_func>
void rbc_low_dim_pass_one(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
Expand All @@ -481,7 +483,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
value_int* dists_counter)
{
if (k <= 32)
block_rbc_kernel_registers<value_idx, value_t, 32, 2, 128, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 32, 2, 128, dims, value_int>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand Down Expand Up @@ -518,7 +520,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
dfunc,
weight);
else if (k <= 128)
block_rbc_kernel_registers<value_idx, value_t, 128, 3, 128, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 128, 3, 128, dims, value_int>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -537,7 +539,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
weight);

else if (k <= 256)
block_rbc_kernel_registers<value_idx, value_t, 256, 4, 128, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 256, 4, 128, dims, value_int>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -556,7 +558,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
weight);

else if (k <= 512)
block_rbc_kernel_registers<value_idx, value_t, 512, 8, 64, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 512, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -575,7 +577,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
weight);

else if (k <= 1024)
block_rbc_kernel_registers<value_idx, value_t, 1024, 8, 64, 2, value_int>
block_rbc_kernel_registers<value_idx, value_t, 1024, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -597,6 +599,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle,
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int dims = 2,
typename dist_func>
void rbc_low_dim_pass_two(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
Expand All @@ -616,7 +619,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
rmm::device_uvector<std::uint32_t> bitset(bitset_size * index.m, handle.get_stream());
thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0);

perform_post_filter_registers<value_idx, value_t, value_int, 128>
perform_post_filter_registers<value_idx, value_t, value_int, dims, 128>
<<<n_query_rows, 128, bitset_size * sizeof(std::uint32_t), handle.get_stream()>>>(
index.get_X(),
index.n,
Expand All @@ -640,7 +643,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
32,
2,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -665,7 +668,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
64,
3,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -690,7 +693,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
128,
3,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -715,7 +718,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
256,
4,
128,
2>
dims>
<<<n_query_rows, 128, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -740,7 +743,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
512,
8,
64,
2>
dims>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand All @@ -765,7 +768,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
1024,
8,
64,
2>
dims>
<<<n_query_rows, 64, 0, handle.get_stream()>>>(index.get_X(),
query,
index.n,
Expand Down
Loading