diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e3cdcbf760..94b24be853 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -444,6 +444,7 @@ if(RAFT_COMPILE_LIBRARY) src/random/rmat_rectangular_generator_int64_double.cu src/random/rmat_rectangular_generator_int_float.cu src/random/rmat_rectangular_generator_int64_float.cu + src/spatial/knn/detail/ball_cover/registers.cu ) set_target_properties( raft_lib diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh index 0a6718f5a5..5522e867fd 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh @@ -17,6 +17,7 @@ #pragma once #include "../haversine_distance.cuh" +#include "registers-types.cuh" #include #include #include @@ -39,42 +40,6 @@ struct NNComp { } }; -template -struct DistFunc { - virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) - { - return -1; - }; -}; - -template -struct HaversineFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); - } -}; - -template -struct EuclideanFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - value_t sum_sq = 0; - for (value_int i = 0; i < n_dims; ++i) { - value_t diff = a[i] - b[i]; - sum_sq += diff * diff; - } - - return raft::sqrt(sum_sq); - } -}; - /** * Zeros the bit at location h in a one-hot encoded 32-bit int array */ @@ -105,4 +70,4 @@ __device__ inline bool _get_val(std::uint32_t* arr, std::uint32_t h) }; // namespace detail }; // namespace knn }; // namespace spatial -}; // namespace raft \ No newline at end of file +}; // namespace 
raft
diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh
new file mode 100644
index 0000000000..b5b54c62a7
--- /dev/null
+++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "../../ball_cover_types.hpp"
+#include "registers-types.cuh"  // DistFunc
+#include <cstdint>              // uint32_t
+#include <raft/util/raft_explicit.hpp>  //RAFT_EXPLICIT
+
+#if defined(RAFT_EXPLICIT_INSTANTIATE)
+
+namespace raft::spatial::knn::detail {
+
+/**
+ * Declaration-only stub of rbc_low_dim_pass_one.
+ *
+ * The body is RAFT_EXPLICIT, so implicitly instantiating this template fails to
+ * compile (static_assert in raft_explicit.hpp). Callers must go through one of
+ * the `extern template` specializations declared at the bottom of this header.
+ */
+template <typename value_idx,
+          typename value_t,
+          typename value_int = std::uint32_t,
+          int dims           = 2,
+          typename dist_func>
+void rbc_low_dim_pass_one(raft::device_resources const& handle,
+                          const BallCoverIndex<value_idx, value_t, value_int>& index,
+                          const value_t* query,
+                          const value_int n_query_rows,
+                          value_int k,
+                          const value_idx* R_knn_inds,
+                          const value_t* R_knn_dists,
+                          dist_func& dfunc,
+                          value_idx* inds,
+                          value_t* dists,
+                          float weight,
+                          value_int* dists_counter) RAFT_EXPLICIT;
+
+/**
+ * Declaration-only stub of rbc_low_dim_pass_two; same RAFT_EXPLICIT contract as above.
+ */
+template <typename value_idx,
+          typename value_t,
+          typename value_int = std::uint32_t,
+          int dims           = 2,
+          typename dist_func>
+void rbc_low_dim_pass_two(raft::device_resources const& handle,
+                          const BallCoverIndex<value_idx, value_t, value_int>& index,
+                          const value_t* query,
+                          const value_int n_query_rows,
+                          value_int k,
+                          const value_idx* R_knn_inds,
+                          const value_t* R_knn_dists,
+                          dist_func& dfunc,
+                          value_idx* inds,
+                          value_t* dists,
+                          float weight,
+                          value_int* post_dists_counter) RAFT_EXPLICIT;
+
+};  // namespace raft::spatial::knn::detail
+
+#endif  // RAFT_EXPLICIT_INSTANTIATE
+
+// The macros below emit `extern template` declarations, which stop every
+// translation unit that includes this header from instantiating the listed
+// specializations itself. The matching explicit instantiation definitions must
+// be provided by the compiled library (src/spatial/knn/detail/ball_cover/registers.cu).
+#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \
+  Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \
+  extern template void \
+  raft::spatial::knn::detail::rbc_low_dim_pass_one<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
+    raft::device_resources const& handle, \
+    const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \
+    const Mvalue_t* query, \
+    const Mvalue_int n_query_rows, \
+    Mvalue_int k, \
+    const Mvalue_idx* R_knn_inds, \
+    const Mvalue_t* R_knn_dists, \
+    Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \
+    Mvalue_idx* inds, \
+    Mvalue_t* dists, \
+    float weight, \
+    Mvalue_int* dists_counter)
+
+#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \
+  Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \
+  extern template void \
+  raft::spatial::knn::detail::rbc_low_dim_pass_two<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
+    raft::device_resources const& handle, \
+    const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \
+    const Mvalue_t* query, \
+    const Mvalue_int n_query_rows, \
+    Mvalue_int k, \
+    const Mvalue_idx* R_knn_inds, \
+    const Mvalue_t* R_knn_dists, \
+    Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \
+    Mvalue_idx* inds, \
+    Mvalue_t* dists, \
+    float weight, \
+    Mvalue_int* post_dists_counter)
+
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);
+
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);
+
+#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two
+#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one
diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
index f665368c41..9c624dcb08 100644
--- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
@@ -20,6 +20,7 @@
 
 #include "../../ball_cover_types.hpp"
 #include "../haversine_distance.cuh"
+#include "registers-types.cuh"  // DistFunc
 
 #include <cstdint>
 #include <limits.h>
diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-types.cuh
new file mode 100644
index 0000000000..7f4268d2dc
--- /dev/null
+++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-types.cuh
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "../haversine_distance.cuh"  // compute_haversine
+#include <cstdint>                    // uint32_t
+
+namespace raft::spatial::knn::detail {
+
+// Base distance functor: operator() is virtual and returns -1 unless a derived
+// functor overrides it. Callable from both host and device code.
+template <typename value_t, typename value_int = std::uint32_t>
+struct DistFunc {
+  virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
+                                                                 const value_t* b,
+                                                                 const value_int n_dims)
+  {
+    return -1;
+  }
+};
+
+// Haversine distance functor: reads only the first two coordinates of each point
+// (a[0], a[1], b[0], b[1]); n_dims is ignored.
+template <typename value_t, typename value_int = std::uint32_t>
+struct HaversineFunc : public DistFunc<value_t, value_int> {
+  __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
+                                                         const value_t* b,
+                                                         const value_int n_dims) override
+  {
+    return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]);
+  }
+};
+
+// Euclidean (L2) distance functor: square root of the summed squared
+// differences over the first n_dims coordinates of a and b.
+template <typename value_t, typename value_int = std::uint32_t>
+struct EuclideanFunc : public DistFunc<value_t, value_int> {
+  __device__ __host__ __forceinline__ value_t operator()(const value_t* a,
+                                                         const value_t* b,
+                                                         const value_int n_dims) override
+  {
+    value_t sum_sq = 0;
+    for (value_int i = 0; i < n_dims; ++i) {
+      value_t diff = a[i] - b[i];
+      sum_sq += diff * diff;
+    }
+
+    return raft::sqrt(sum_sq);
+  }
+};
+
+}  // namespace raft::spatial::knn::detail
diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
index e69de29bb2..399d4b07c6 100644
--- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if defined(RAFT_COMPILED) && defined(RAFT_EXPLICIT_INSTANTIATE) +#include "registers-ext.cuh" +#else +#include "registers-inl.cuh" +#endif diff --git a/cpp/include/raft/util/raft_explicit.hpp b/cpp/include/raft/util/raft_explicit.hpp new file mode 100644 index 0000000000..fd81fe23de --- /dev/null +++ b/cpp/include/raft/util/raft_explicit.hpp @@ -0,0 +1,63 @@ +/* Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#pragma once
+
+// RAFT_EXPLICIT is used as the body of a function template declaration. It calls
+// a helper whose static_assert fails whenever that body gets instantiated, i.e.
+// whenever the template is implicitly instantiated by the including translation unit.
+#define RAFT_EXPLICIT                                                     \
+  {                                                                       \
+    raft::util::raft_explicit::do_not_implicitly_instantiate_templates(); \
+  }
+
+namespace raft::util::raft_explicit {
+
+// To make sure the static_assert only fires when
+// do_not_implicitly_instantiate_templates is instantiated, its condition depends
+// on a template parameter, as described in P2593:
+// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html
+// The parameter defaults to false, so the bare call emitted by RAFT_EXPLICIT always fails.
+template <bool implicit_instantiation_allowed = false>
+void do_not_implicitly_instantiate_templates()
+{
+  static_assert(implicit_instantiation_allowed,
+                "ACCIDENTAL_IMPLICIT_INSTANTIATION\n\n"
+
+                "If you see this error, then you have implicitly instantiated a function\n"
+                "template. To keep compile times in check, libraft has the policy of\n"
+                "explicitly instantiating templates. To fix the compilation error, follow\n"
+                "these steps.\n\n"
+
+                "If you scroll up a bit in your error message, you probably saw two lines\n"
+                "like the following:\n\n"
+
+                "[.. snip ..] required from ‘void "
+                "raft::util::raft_explicit::do_not_implicitly_instantiate_templates() "
+                "[with bool implicit_instantiation_allowed = false]’\n"
+                "[.. snip ..] from ‘void raft::bar(T) [with T = double]’\n\n"
+
+                "Simple solution:\n\n"
+                "    Add '#undef RAFT_EXPLICIT_INSTANTIATE' at the top of your .cpp/.cu file.\n\n"
+
+                "Best solution:\n\n"
+                "    1. Add the following line to the file include/raft/bar.hpp:\n\n"
+                "        extern template void raft::bar<double>(double);\n\n"
+                "    2. Add the following line to the file src/raft/bar.cpp:\n\n"
+                "        template void raft::bar<double>(double);\n\n"
+
+                "Most likely there are many other similar lines in both files.\n");
+}
+
+}  // namespace raft::util::raft_explicit
diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers.cu b/cpp/src/spatial/knn/detail/ball_cover/registers.cu
new file mode 100644
index 0000000000..0bb6d123a9
--- /dev/null
+++ b/cpp/src/spatial/knn/detail/ball_cover/registers.cu
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdint>  // int64_t, uint32_t
+#include <raft/spatial/knn/detail/ball_cover/registers-inl.cuh>
+
+// Explicit instantiation definitions for every specialization that
+// registers-ext.cuh declares `extern template`. The two lists must stay in sync,
+// otherwise callers of a declared-but-undefined specialization fail at link time.
+#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \
+  Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \
+  template void \
+  raft::spatial::knn::detail::rbc_low_dim_pass_one<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
+    raft::device_resources const& handle, \
+    const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \
+    const Mvalue_t* query, \
+    const Mvalue_int n_query_rows, \
+    Mvalue_int k, \
+    const Mvalue_idx* R_knn_inds, \
+    const Mvalue_t* R_knn_dists, \
+    Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \
+    Mvalue_idx* inds, \
+    Mvalue_t* dists, \
+    float weight, \
+    Mvalue_int* dists_counter)
+
+#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \
+  Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \
+  template void \
+  raft::spatial::knn::detail::rbc_low_dim_pass_two<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \
+    raft::device_resources const& handle, \
+    const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \
+    const Mvalue_t* query, \
+    const Mvalue_int n_query_rows, \
+    Mvalue_int k, \
+    const Mvalue_idx* R_knn_inds, \
+    const Mvalue_t* R_knn_dists, \
+    Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \
+    Mvalue_idx* inds, \
+    Mvalue_t* dists, \
+    float weight, \
+    Mvalue_int* dists_counter)
+
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);
+
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc);
+instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
+  std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc);
+
+#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two
+#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one