diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh
index 4fc55c29b8..ff95bb56cd 100644
--- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh
@@ -65,6 +65,14 @@ void pairwise_matrix_dispatch(OpT distance_op,
cudaStream_t stream, \
bool is_row_major)
+/*
+ * Hierarchy of instantiations:
+ *
+ * This file defines extern template instantiations of the distance kernels. The
+ * instantiation of the public API is handled in raft/distance/distance-ext.cuh.
+ *
+ * After adding an instance here, make sure to also add the instance there.
+ */
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh
index 4df42ca72e..31d4dc28a1 100644
--- a/cpp/include/raft/distance/distance-ext.cuh
+++ b/cpp/include/raft/distance/distance-ext.cuh
@@ -331,6 +331,18 @@ void pairwise_distance(raft::resources const& handle,
#endif // RAFT_EXPLICIT_INSTANTIATE
+/*
+ * Hierarchy of instantiations:
+ *
+ * This file defines the extern template instantiations for the public API of
+ * raft::distance. To improve compile times, the extern template instantiation
+ * of the distance kernels is handled in
+ * distance/detail/pairwise_matrix/dispatch-ext.cuh.
+ *
+ * After adding an instance here, make sure to also add the instance to
+ * dispatch-ext.cuh and the corresponding .cu files.
+ */
+
#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \
extern template void raft::distance::distance
( \
raft::resources const& handle, \
diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch.cu
deleted file mode 100644
index 7b91b3c3bf..0000000000
--- a/cpp/src/distance/detail/pairwise_matrix/dispatch.cu
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include // ops::*
-#include
-
-#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \
- OpT, DataT, AccT, OutT, FinOpT, IdxT) \
- template void raft::distance::detail:: \
- pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \
- OpT distance_op, \
- IdxT m, \
- IdxT n, \
- IdxT k, \
- const DataT* x, \
- const DataT* y, \
- const DataT* x_norm, \
- const DataT* y_norm, \
- OutT* out, \
- FinOpT fin_op, \
- cudaStream_t stream, \
- bool is_row_major)
-
-instantiate_raft_distance_detail_pairwise_matrix_dispatch(
- raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::canberra_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::correlation_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::correlation_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::cosine_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::cosine_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hamming_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hamming_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hellinger_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::hellinger_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::jensen_shannon_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::jensen_shannon_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::kl_divergence_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::kl_divergence_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l1_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l1_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_exp_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_exp_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_unexp_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l2_unexp_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l_inf_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::l_inf_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::lp_unexp_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::lp_unexp_distance_op,
-// double, double, double, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::russel_rao_distance_op,
-// float, float, float, raft::identity_op, int);
-// instantiate_raft_distance_detail_pairwise_matrix_dispatch(raft::distance::detail::ops::russel_rao_distance_op,
-// double, double, double, raft::identity_op, int);
-
-#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch
diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu
index f986dd30ef..91bd506724 100644
--- a/cpp/src/distance/distance.cu
+++ b/cpp/src/distance/distance.cu
@@ -16,6 +16,15 @@
#include
+/*
+ * Hierarchy of instantiations:
+ *
+ * This file defines the template instantiations for the public API of
+ * raft::distance. To improve compile times, the compilation of the distance
+ * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu.
+ *
+ */
+
#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \
template void raft::distance::distance( \
raft::resources const& handle, \