Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split raft/spatial/knn/detail/ball_cover/registers.cuh
Browse files Browse the repository at this point in the history
This dramatically reduces the compile times of ball_cover_knn_query.cu and ball_cover_all_knn_query.cu.

They used to take 900 seconds each; now they take ~25 seconds.
ahendriksen committed Apr 14, 2023
1 parent 40fe15b commit 95638fd
Showing 8 changed files with 345 additions and 37 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -444,6 +444,7 @@ if(RAFT_COMPILE_LIBRARY)
src/random/rmat_rectangular_generator_int64_double.cu
src/random/rmat_rectangular_generator_int_float.cu
src/random/rmat_rectangular_generator_int64_float.cu
src/spatial/knn/detail/ball_cover/registers.cu
)
set_target_properties(
raft_lib
39 changes: 2 additions & 37 deletions cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
#pragma once

#include "../haversine_distance.cuh"
#include "registers-types.cuh"
#include <cstdint>
#include <thrust/functional.h>
#include <thrust/tuple.h>
@@ -39,42 +40,6 @@ struct NNComp {
}
};

/**
 * Base functor type for pairwise distance computations.
 *
 * The call operator defined here is a placeholder that always yields -1;
 * derived functors in this file replace it via `override`.
 *
 * @tparam value_t   element type of the input arrays and of the result
 * @tparam value_int integer type used for the dimension count
 */
template <typename value_t, typename value_int = std::uint32_t>
struct DistFunc {
  /**
   * Placeholder distance.
   *
   * @param a      first input array (not read here)
   * @param b      second input array (not read here)
   * @param n_dims number of dimensions (not read here)
   * @return always -1
   */
  virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
                                                                 const value_t* b,
                                                                 const value_int n_dims)
  {
    return static_cast<value_t>(-1);
  }
};

/**
 * Distance functor that forwards to compute_haversine.
 *
 * Only the first two elements of each input array are read; `n_dims` is
 * accepted to match the DistFunc interface but is not used.
 *
 * @tparam value_t   element type of the input arrays and of the result
 * @tparam value_int integer type used for the dimension count
 */
template <typename value_t, typename value_int = std::uint32_t>
struct HaversineFunc : public DistFunc<value_t, value_int> {
  /**
   * @param a      first input array (elements 0 and 1 are read)
   * @param b      second input array (elements 0 and 1 are read)
   * @param n_dims unused
   * @return compute_haversine(a[0], b[0], a[1], b[1])
   */
  __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
                                                         const value_t* b,
                                                         const value_int n_dims) override
  {
    const value_t a_first  = a[0];
    const value_t b_first  = b[0];
    const value_t a_second = a[1];
    const value_t b_second = b[1];
    return raft::spatial::knn::detail::compute_haversine(a_first, b_first, a_second, b_second);
  }
};

/**
 * Distance functor computing the square root (via raft::sqrt) of the sum of
 * squared element-wise differences over the first `n_dims` elements.
 *
 * @tparam value_t   element type of the input arrays and of the result
 * @tparam value_int integer type used for the dimension count
 */
template <typename value_t, typename value_int = std::uint32_t>
struct EuclideanFunc : public DistFunc<value_t, value_int> {
  /**
   * @param a      first input array, at least n_dims elements are read
   * @param b      second input array, at least n_dims elements are read
   * @param n_dims number of leading elements to include
   * @return raft::sqrt of the accumulated squared differences
   */
  __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
                                                         const value_t* b,
                                                         const value_int n_dims) override
  {
    value_t acc   = 0;
    value_int dim = 0;
    // Accumulate in index order so the summation order is unchanged.
    while (dim < n_dims) {
      const value_t delta = a[dim] - b[dim];
      acc += delta * delta;
      ++dim;
    }
    return raft::sqrt(acc);
  }
};

/**
* Zeros the bit at location h in a one-hot encoded 32-bit int array
*/
@@ -105,4 +70,4 @@ __device__ inline bool _get_val(std::uint32_t* arr, std::uint32_t h)
}; // namespace detail
}; // namespace knn
}; // namespace spatial
}; // namespace raft
}; // namespace raft
129 changes: 129 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "../../ball_cover_types.hpp"
#include "registers-types.cuh" // DistFunc
#include <cstdint> // uint32_t
#include <raft/util/raft_explicit.hpp> //RAFT_EXPLICIT

#if defined(RAFT_EXPLICIT_INSTANTIATE)

namespace raft::spatial::knn::detail {

/**
 * Declaration-only stub of rbc_low_dim_pass_one, compiled when
 * RAFT_EXPLICIT_INSTANTIATE is defined.
 *
 * The body is RAFT_EXPLICIT (see <raft/util/raft_explicit.hpp>), which expands
 * to a call to do_not_implicitly_instantiate_templates(); the static_assert in
 * that helper fails the build if this template is instantiated implicitly.
 * The supported specializations are declared `extern template` further down
 * in this header.
 *
 * `dist_func` has no default and is not the last defaulted parameter, so it is
 * expected to be deduced from the `dfunc` argument.
 *
 * NOTE(review): the meaning of the individual parameters is defined by the
 * implementation, which is not part of this header (presumably
 * registers-inl.cuh) -- confirm there before relying on them.
 */
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int dims = 2,
typename dist_func>
void rbc_low_dim_pass_one(raft::device_resources const& handle,
const BallCoverIndex<value_idx, value_t, value_int>& index,
const value_t* query,
const value_int n_query_rows,
value_int k,
const value_idx* R_knn_inds,
const value_t* R_knn_dists,
dist_func& dfunc,
value_idx* inds,
value_t* dists,
float weight,
value_int* dists_counter) RAFT_EXPLICIT;

/**
 * Declaration-only stub of rbc_low_dim_pass_two, compiled when
 * RAFT_EXPLICIT_INSTANTIATE is defined.
 *
 * The body is RAFT_EXPLICIT (see <raft/util/raft_explicit.hpp>), which expands
 * to a call to do_not_implicitly_instantiate_templates(); the static_assert in
 * that helper fails the build if this template is instantiated implicitly.
 * The supported specializations are declared `extern template` further down
 * in this header.
 *
 * The parameter list mirrors rbc_low_dim_pass_one except that the final
 * counter is named `post_dists_counter`.
 *
 * NOTE(review): the meaning of the individual parameters is defined by the
 * implementation, which is not part of this header (presumably
 * registers-inl.cuh) -- confirm there before relying on them.
 */
template <typename value_idx,
typename value_t,
typename value_int = std::uint32_t,
int dims = 2,
typename dist_func>
void rbc_low_dim_pass_two(raft::device_resources const& handle,
const BallCoverIndex<value_idx, value_t, value_int>& index,
const value_t* query,
const value_int n_query_rows,
value_int k,
const value_idx* R_knn_inds,
const value_t* R_knn_dists,
dist_func& dfunc,
value_idx* inds,
value_t* dists,
float weight,
value_int* post_dists_counter) RAFT_EXPLICIT;

}; // namespace raft::spatial::knn::detail

#endif // RAFT_EXPLICIT_INSTANTIATE

// Emits an explicit instantiation *declaration* (`extern template`) for
// rbc_low_dim_pass_one. Only the first four template arguments are spelled in
// the <...> list; the dist_func argument is deduced from the type of `dfunc`,
// i.e. Mdist_func<Mvalue_t, Mvalue_int>. Mdist_func must therefore be the name
// of a class template taking <value_t, value_int>.
// NOTE(review): every declaration produced here needs a matching explicit
// instantiation definition in a compiled translation unit, otherwise uses of
// that specialization fail at link time.
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \
Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \
extern template void \
raft::spatial::knn::detail::rbc_low_dim_pass_one<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
raft::device_resources const& handle, \
const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \
const Mvalue_t* query, \
const Mvalue_int n_query_rows, \
Mvalue_int k, \
const Mvalue_idx* R_knn_inds, \
const Mvalue_t* R_knn_dists, \
Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \
Mvalue_idx* inds, \
Mvalue_t* dists, \
float weight, \
Mvalue_int* dists_counter)

// Same as above, for rbc_low_dim_pass_two. The last parameter is spelled
// `dists_counter` here; parameter names do not take part in matching the
// declaration.
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \
Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \
extern template void \
raft::spatial::knn::detail::rbc_low_dim_pass_two<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
raft::device_resources const& handle, \
const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \
const Mvalue_t* query, \
const Mvalue_int n_query_rows, \
Mvalue_int k, \
const Mvalue_idx* R_knn_inds, \
const Mvalue_t* R_knn_dists, \
Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \
Mvalue_idx* inds, \
Mvalue_t* dists, \
float weight, \
Mvalue_int* dists_counter)

// Pass one: int64 indices, float values, uint32 sizes, dims in {2, 3}, for
// each of HaversineFunc, EuclideanFunc and the DistFunc base.
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);

// Pass two: the same six combinations as pass one.
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);

// The helper macros are local to this header; remove them so they do not leak
// into including translation units.
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@

#include "../../ball_cover_types.hpp"
#include "../haversine_distance.cuh"
#include "registers-types.cuh" // DistFunc

#include <cstdint>
#include <limits.h>
66 changes: 66 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-types.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "../haversine_distance.cuh" // compute_haversine
#include <cstdint> // uint32_t

namespace raft {
namespace spatial {
namespace knn {
namespace detail {

/**
 * Base functor type for pairwise distance computations.
 *
 * The call operator defined here is a placeholder that always yields -1;
 * derived functors in this file replace it via `override`.
 *
 * @tparam value_t   element type of the input arrays and of the result
 * @tparam value_int integer type used for the dimension count
 */
template <typename value_t, typename value_int = std::uint32_t>
struct DistFunc {
  /**
   * Placeholder distance.
   *
   * @param a      first input array (not read here)
   * @param b      second input array (not read here)
   * @param n_dims number of dimensions (not read here)
   * @return always -1
   */
  virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
                                                                 const value_t* b,
                                                                 const value_int n_dims)
  {
    return static_cast<value_t>(-1);
  }
};

/**
 * Distance functor that forwards to compute_haversine.
 *
 * Only the first two elements of each input array are read; `n_dims` is
 * accepted to match the DistFunc interface but is not used.
 *
 * @tparam value_t   element type of the input arrays and of the result
 * @tparam value_int integer type used for the dimension count
 */
template <typename value_t, typename value_int = std::uint32_t>
struct HaversineFunc : public DistFunc<value_t, value_int> {
  /**
   * @param a      first input array (elements 0 and 1 are read)
   * @param b      second input array (elements 0 and 1 are read)
   * @param n_dims unused
   * @return compute_haversine(a[0], b[0], a[1], b[1])
   */
  __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
                                                         const value_t* b,
                                                         const value_int n_dims) override
  {
    const value_t a_first  = a[0];
    const value_t b_first  = b[0];
    const value_t a_second = a[1];
    const value_t b_second = b[1];
    return raft::spatial::knn::detail::compute_haversine(a_first, b_first, a_second, b_second);
  }
};

/**
 * Distance functor computing the square root (via raft::sqrt) of the sum of
 * squared element-wise differences over the first `n_dims` elements.
 *
 * @tparam value_t   element type of the input arrays and of the result
 * @tparam value_int integer type used for the dimension count
 */
template <typename value_t, typename value_int = std::uint32_t>
struct EuclideanFunc : public DistFunc<value_t, value_int> {
  /**
   * @param a      first input array, at least n_dims elements are read
   * @param b      second input array, at least n_dims elements are read
   * @param n_dims number of leading elements to include
   * @return raft::sqrt of the accumulated squared differences
   */
  __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
                                                         const value_t* b,
                                                         const value_int n_dims) override
  {
    value_t acc   = 0;
    value_int dim = 0;
    // Accumulate in index order so the summation order is unchanged.
    while (dim < n_dims) {
      const value_t delta = a[dim] - b[dim];
      acc += delta * delta;
      ++dim;
    }
    return raft::sqrt(acc);
  }
};

}; // namespace detail
}; // namespace knn
}; // namespace spatial
}; // namespace raft
23 changes: 23 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#if defined(RAFT_COMPILED) && defined(RAFT_EXPLICIT_INSTANTIATE)
#include "registers-ext.cuh"
#else
#include "registers-inl.cuh"
#endif
63 changes: 63 additions & 0 deletions cpp/include/raft/util/raft_explicit.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

// Function body to append to a function-template declaration that must only
// be used through explicit instantiations. It expands to a braced body whose
// single statement calls
// raft::util::raft_explicit::do_not_implicitly_instantiate_templates(); the
// static_assert inside that helper then turns an implicit instantiation of
// the enclosing template into a compile error.
#define RAFT_EXPLICIT \
{ \
raft::util::raft_explicit::do_not_implicitly_instantiate_templates(); \
}

namespace raft::util::raft_explicit {

/**
 * Compile-time tripwire called from the body that RAFT_EXPLICIT expands to.
 *
 * With the default template argument (`false`) the static_assert below fails
 * as soon as this function template is instantiated, which happens when a
 * function template whose body is RAFT_EXPLICIT gets instantiated implicitly.
 * The diagnostic text explains how to resolve the error.
 *
 * To make sure the static_assert only fires when
 * do_not_implicitly_instantiate_templates is instantiated, the condition
 * depends on a dummy template parameter as described in P2593:
 * https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html
 *
 * @tparam implicit_instantiation_allowed leave at the default (`false`);
 *         passing `true` disables the assertion.
 */
template <bool implicit_instantiation_allowed = false>
void do_not_implicitly_instantiate_templates()
{
  static_assert(implicit_instantiation_allowed,
                "ACCIDENTAL_IMPLICIT_INSTANTIATION\n\n"

                "If you see this error, then you have implicitly instantiated a function\n"
                "template. To keep compile times in check, RAFT has the policy of\n"
                "explicitly instantiating templates. To fix the compilation error, follow\n"
                "these steps.\n\n"

                "If you scroll up a bit in your error message, you probably saw two lines\n"
                "like the following:\n\n"

                "[.. snip ..] required from ‘void "
                "raft::util::raft_explicit::do_not_implicitly_instantiate_templates() "
                "[with bool implicit_instantiation_allowed = false]’\n"
                "[.. snip ..] from ‘void raft::bar(T) [with T = double]’\n\n"

                "Simple solution:\n\n"

                "    Add '#undef RAFT_EXPLICIT_INSTANTIATE' at the top of your .cpp/.cu file.\n\n"

                "Best solution:\n\n"

                "    1. Add the following line to the file include/raft/bar.hpp:\n\n"

                "        extern template void raft::bar<double>(double);\n\n"

                "    2. Add the following line to the file src/raft/bar.cpp:\n\n"

                "        template void raft::bar<double>(double);\n\n"

                "There are probably many other similar lines in both files.\n");
}

} // namespace raft::util::raft_explicit
60 changes: 60 additions & 0 deletions cpp/src/spatial/knn/detail/ball_cover/registers.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/spatial/knn/detail/ball_cover/registers-inl.cuh>

// Explicit instantiation *definitions* for rbc_low_dim_pass_one /
// rbc_low_dim_pass_two. registers-ext.cuh declares `extern template` for the
// HaversineFunc, EuclideanFunc and DistFunc distance functors, which
// suppresses implicit instantiation in every translation unit that includes
// it. Each of those declarations therefore needs a definition here; otherwise
// callers that pass a HaversineFunc or EuclideanFunc fail to link.
//
// The dist_func template argument is not spelled in the <...> list; it is
// deduced from the type of `dfunc`, i.e. Mdist_func<Mvalue_t, Mvalue_int>.
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(                            \
  Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func)                                       \
  template void                                                                              \
  raft::spatial::knn::detail::rbc_low_dim_pass_one<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
    raft::device_resources const& handle,                                                    \
    const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index,                           \
    const Mvalue_t* query,                                                                   \
    const Mvalue_int n_query_rows,                                                           \
    Mvalue_int k,                                                                            \
    const Mvalue_idx* R_knn_inds,                                                            \
    const Mvalue_t* R_knn_dists,                                                             \
    Mdist_func<Mvalue_t, Mvalue_int>& dfunc,                                                 \
    Mvalue_idx* inds,                                                                        \
    Mvalue_t* dists,                                                                         \
    float weight,                                                                            \
    Mvalue_int* dists_counter)

#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(                            \
  Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func)                                       \
  template void                                                                              \
  raft::spatial::knn::detail::rbc_low_dim_pass_two<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
    raft::device_resources const& handle,                                                    \
    const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index,                           \
    const Mvalue_t* query,                                                                   \
    const Mvalue_int n_query_rows,                                                           \
    Mvalue_int k,                                                                            \
    const Mvalue_idx* R_knn_inds,                                                            \
    const Mvalue_t* R_knn_dists,                                                             \
    Mdist_func<Mvalue_t, Mvalue_int>& dfunc,                                                 \
    Mvalue_idx* inds,                                                                        \
    Mvalue_t* dists,                                                                         \
    float weight,                                                                            \
    Mvalue_int* dists_counter)

// Keep this list in sync with the extern declarations in registers-ext.cuh.
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);

instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);

#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one

0 comments on commit 95638fd

Please sign in to comment.