forked from rapidsai/raft
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split raft/spatial/knn/detail/ball_cover/registers.cuh
This dramatically reduces the compile times of ball_cover_knn_query.cu and ball_cover_all_knn_query.cu They used to take 900 seconds. Now they take ~25s.
1 parent
40fe15b
commit 95638fd
Showing
8 changed files
with
345 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
129 changes: 129 additions & 0 deletions
129
cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
/* | ||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "../../ball_cover_types.hpp" | ||
#include "registers-types.cuh" // DistFunc | ||
#include <cstdint> // uint32_t | ||
#include <raft/util/raft_explicit.hpp> //RAFT_EXPLICIT | ||
|
||
#if defined(RAFT_EXPLICIT_INSTANTIATE) | ||
|
||
namespace raft::spatial::knn::detail { | ||
|
||
template <typename value_idx, | ||
typename value_t, | ||
typename value_int = std::uint32_t, | ||
int dims = 2, | ||
typename dist_func> | ||
void rbc_low_dim_pass_one(raft::device_resources const& handle, | ||
const BallCoverIndex<value_idx, value_t, value_int>& index, | ||
const value_t* query, | ||
const value_int n_query_rows, | ||
value_int k, | ||
const value_idx* R_knn_inds, | ||
const value_t* R_knn_dists, | ||
dist_func& dfunc, | ||
value_idx* inds, | ||
value_t* dists, | ||
float weight, | ||
value_int* dists_counter) RAFT_EXPLICIT; | ||
|
||
template <typename value_idx, | ||
typename value_t, | ||
typename value_int = std::uint32_t, | ||
int dims = 2, | ||
typename dist_func> | ||
void rbc_low_dim_pass_two(raft::device_resources const& handle, | ||
const BallCoverIndex<value_idx, value_t, value_int>& index, | ||
const value_t* query, | ||
const value_int n_query_rows, | ||
value_int k, | ||
const value_idx* R_knn_inds, | ||
const value_t* R_knn_dists, | ||
dist_func& dfunc, | ||
value_idx* inds, | ||
value_t* dists, | ||
float weight, | ||
value_int* post_dists_counter) RAFT_EXPLICIT; | ||
|
||
}; // namespace raft::spatial::knn::detail | ||
|
||
#endif // RAFT_EXPLICIT_INSTANTIATE | ||
|
||
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ | ||
Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ | ||
extern template void \ | ||
raft::spatial::knn::detail::rbc_low_dim_pass_one<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \ | ||
raft::device_resources const& handle, \ | ||
const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \ | ||
const Mvalue_t* query, \ | ||
const Mvalue_int n_query_rows, \ | ||
Mvalue_int k, \ | ||
const Mvalue_idx* R_knn_inds, \ | ||
const Mvalue_t* R_knn_dists, \ | ||
Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \ | ||
Mvalue_idx* inds, \ | ||
Mvalue_t* dists, \ | ||
float weight, \ | ||
Mvalue_int* dists_counter) | ||
|
||
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ | ||
Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ | ||
extern template void \ | ||
raft::spatial::knn::detail::rbc_low_dim_pass_two<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \ | ||
raft::device_resources const& handle, \ | ||
const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \ | ||
const Mvalue_t* query, \ | ||
const Mvalue_int n_query_rows, \ | ||
Mvalue_int k, \ | ||
const Mvalue_idx* R_knn_inds, \ | ||
const Mvalue_t* R_knn_dists, \ | ||
Mdist_func<Mvalue_t, Mvalue_int>& dfunc, \ | ||
Mvalue_idx* inds, \ | ||
Mvalue_t* dists, \ | ||
float weight, \ | ||
Mvalue_int* dists_counter) | ||
|
||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( | ||
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( | ||
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( | ||
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( | ||
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( | ||
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( | ||
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); | ||
|
||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( | ||
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( | ||
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( | ||
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( | ||
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( | ||
std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( | ||
std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); | ||
|
||
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two | ||
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
cpp/include/raft/spatial/knn/detail/ball_cover/registers-types.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "../haversine_distance.cuh" // compute_haversine | ||
#include <cstdint> // uint32_t | ||
|
||
namespace raft { | ||
namespace spatial { | ||
namespace knn { | ||
namespace detail { | ||
|
||
template <typename value_t, typename value_int = std::uint32_t> | ||
struct DistFunc { | ||
virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, | ||
const value_t* b, | ||
const value_int n_dims) | ||
{ | ||
return -1; | ||
}; | ||
}; | ||
|
||
template <typename value_t, typename value_int = std::uint32_t> | ||
struct HaversineFunc : public DistFunc<value_t, value_int> { | ||
__device__ __host__ __forceinline__ value_t operator()(const value_t* a, | ||
const value_t* b, | ||
const value_int n_dims) override | ||
{ | ||
return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); | ||
} | ||
}; | ||
|
||
template <typename value_t, typename value_int = std::uint32_t> | ||
struct EuclideanFunc : public DistFunc<value_t, value_int> { | ||
__device__ __host__ __forceinline__ value_t operator()(const value_t* a, | ||
const value_t* b, | ||
const value_int n_dims) override | ||
{ | ||
value_t sum_sq = 0; | ||
for (value_int i = 0; i < n_dims; ++i) { | ||
value_t diff = a[i] - b[i]; | ||
sum_sq += diff * diff; | ||
} | ||
|
||
return raft::sqrt(sum_sq); | ||
} | ||
}; | ||
|
||
}; // namespace detail | ||
}; // namespace knn | ||
}; // namespace spatial | ||
}; // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#if defined(RAFT_COMPILED) && defined(RAFT_EXPLICIT_INSTANTIATE) | ||
#include "registers-ext.cuh" | ||
#else | ||
#include "registers-inl.cuh" | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#define RAFT_EXPLICIT \ | ||
{ \ | ||
raft::util::raft_explicit::do_not_implicitly_instantiate_templates(); \ | ||
} | ||
|
||
namespace raft::util::raft_explicit { | ||
|
||
// To make sure the static_assert only fires when | ||
// do_not_implicitly_instantiate_templates is instantiated, we use a dummy | ||
// template parameter as described in P2593: | ||
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html | ||
template <bool implicit_instantiation_allowed = false> | ||
void do_not_implicitly_instantiate_templates() | ||
{ | ||
static_assert(implicit_instantiation_allowed, | ||
"ACCIDENTAL_IMPLICIT_INSTANTIATION\n\n" | ||
|
||
"If you see this error, then you have implicitly instantiated a function\n" | ||
"template. To keep compile times in check, libfoo has the policy of\n" | ||
"explicitly instantiating templates. To fix the compilation error, follow\n" | ||
"these steps.\n\n" | ||
|
||
"If you scroll up a bit in your error message, you probably saw two lines\n" | ||
"like the following:\n\n" | ||
|
||
"[.. snip ..] required from ‘void raft::do_not_implicitly_instantiate_templates() " | ||
"[with int dummy = 0]’\n" | ||
"[.. snip ..] from ‘void raft::bar(T) [with T = double]’\n\n" | ||
|
||
"Simple solution:\n\n" | ||
|
||
" Add '#undef RAFT_EXPLICIT_INSTANTIATE' at the top of your .cpp/.cu file.\n\n" | ||
|
||
"Best solution:\n\n" | ||
|
||
" 1. Add the following line to the file include/raft/bar.hpp:\n\n" | ||
|
||
" extern template void raft::bar<double>(double);\n\n" | ||
|
||
" 2. Add the following line to the file src/raft/bar.cpp:\n\n" | ||
|
||
" template void raft::bar<double>(double)\n\n" | ||
|
||
"Probability is that there are many other similar lines in both files.\n"); | ||
} | ||
|
||
} // namespace raft::util::raft_explicit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <raft/spatial/knn/detail/ball_cover/registers-inl.cuh> | ||
|
||
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ | ||
Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ | ||
template void \ | ||
raft::spatial::knn::detail::rbc_low_dim_pass_one<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \ | ||
raft::device_resources const& handle, \ | ||
const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \ | ||
const Mvalue_t* query, \ | ||
const Mvalue_int n_query_rows, \ | ||
Mvalue_int k, \ | ||
const Mvalue_idx* R_knn_inds, \ | ||
const Mvalue_t* R_knn_dists, \ | ||
raft::spatial::knn::detail::DistFunc<Mvalue_t, Mvalue_int>& dfunc, \ | ||
Mvalue_idx* inds, \ | ||
Mvalue_t* dists, \ | ||
float weight, \ | ||
Mvalue_int* dists_counter) | ||
|
||
#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ | ||
Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ | ||
template void \ | ||
raft::spatial::knn::detail::rbc_low_dim_pass_two<Mvalue_idx, Mvalue_t, Mvalue_int, Mdims>( \ | ||
raft::device_resources const& handle, \ | ||
const BallCoverIndex<Mvalue_idx, Mvalue_t, Mvalue_int>& index, \ | ||
const Mvalue_t* query, \ | ||
const Mvalue_int n_query_rows, \ | ||
Mvalue_int k, \ | ||
const Mvalue_idx* R_knn_inds, \ | ||
const Mvalue_t* R_knn_dists, \ | ||
raft::spatial::knn::detail::DistFunc<Mvalue_t, Mvalue_int>& dfunc, \ | ||
Mvalue_idx* inds, \ | ||
Mvalue_t* dists, \ | ||
float weight, \ | ||
Mvalue_int* dists_counter) | ||
|
||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 2); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 3); | ||
|
||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 2); | ||
instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 3); | ||
|
||
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two | ||
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one |