diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh index 4fc55c29b8..ff95bb56cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -65,6 +65,14 @@ void pairwise_matrix_dispatch(OpT distance_op, cudaStream_t stream, \ bool is_row_major) +/* + * Hierarchy of instantiations: + * + * This file defines extern template instantiations of the distance kernels. The + * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * + * After adding an instance here, make sure to also add the instance there. + */ instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh index 4df42ca72e..31d4dc28a1 100644 --- a/cpp/include/raft/distance/distance-ext.cuh +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -331,6 +331,18 @@ void pairwise_distance(raft::resources const& handle, #endif // RAFT_EXPLICIT_INSTANTIATE +/* + * Hierarchy of instantiations: + * + * This file defines the extern template instantiations for the public API of + * raft::distance. To improve compile times, the extern template instantiation + * of the distance kernels is handled in + * distance/detail/pairwise_matrix/dispatch-ext.cuh. + * + * After adding an instance here, make sure to also add the instance to + * dispatch-ext.cuh and the corresponding .cu files. + */ + #define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ extern template void raft::distance::distance( \ raft::resources const& handle, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch.cu deleted file mode 100644 index 7b91b3c3bf..0000000000 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch.cu +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include // ops::* -#include - -#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ - OpT, DataT, AccT, OutT, FinOpT, IdxT) \ - template void raft::distance::detail:: \ - pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ - OpT distance_op, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - const DataT* x, \ - const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ - OutT* out, \ - FinOpT fin_op, \ - cudaStream_t stream, \ - bool is_row_major) - -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::canberra_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::correlation_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::correlation_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::cosine_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::cosine_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hamming_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hamming_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hellinger_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hellinger_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::jensen_shannon_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::jensen_shannon_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::kl_divergence_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::kl_divergence_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l1_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l1_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_exp_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_exp_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_unexp_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_unexp_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l_inf_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l_inf_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::lp_unexp_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::lp_unexp_distance_op, -// double, double, double, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::russel_rao_distance_op, -// float, float, float, raft::identity_op, int); -// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::russel_rao_distance_op, -// double, double, double, raft::identity_op, int); - -#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu index f986dd30ef..91bd506724 100644 --- a/cpp/src/distance/distance.cu +++ b/cpp/src/distance/distance.cu @@ -16,6 +16,15 @@ #include +/* + * Hierarchy of instantiations: + * + * This file defines the template instantiations for the public API of + * raft::distance. To improve compile times, the compilation of the distance + * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu. + * + */ + #define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ template void raft::distance::distance( \ raft::resources const& handle, \