From 76c828dd4da4dc922626ba2a440a46dea6ab03b9 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Sat, 25 Mar 2023 20:09:35 +0100 Subject: [PATCH] Add extern template for ivfflat_interleaved_scan (#1360) This should cut compilation time for refine_d_int64_t_float.cu.o et al from ~900 seconds to 29 seconds. The refine specialization contain >100 instances of the ivfflat_interleaved_scan kernel, even though these should be seperately compiled by the ivfflat_search specializations. The call to ivf_flat_interleaved_scan is [here](https://github.com/rapidsai/raft/blob/56ac43ad93a319a61073dce1b3b937f6f13ade63/cpp/include/raft/neighbors/detail/refine.cuh#L121). Depends on (so please merge after) PR #1307. Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1360 --- .../raft/neighbors/detail/ivf_flat_search.cuh | 8 ++++ cpp/include/raft/neighbors/detail/refine.cuh | 8 ++++ .../neighbors/specializations/ivf_flat.cuh | 25 ++++++++++++- .../ivfflat_search_float_int64_t.cu | 37 ++++++++++++++++--- .../ivfflat_search_int8_t_int64_t.cu | 28 +++++++++++--- .../ivfflat_search_uint8_t_int64_t.cu | 28 +++++++++++--- 6 files changed, 115 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index f657070df4..e6533eaf51 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -1065,6 +1065,14 @@ void ivfflat_interleaved_scan(const index& index, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // related function calls after editing this function definition. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. + const int capacity = bound_by_power_of_two(k); select_interleaved_scan_kernel::run(capacity, index.veclen(), diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index f244d5875c..aedfc42698 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -117,6 +117,14 @@ void refine_device(raft::device_resources const& handle, n_queries, n_candidates); + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // and adjust the extern template definition and the instantiation when the + // below function call is edited. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 013c7359e5..161f3462c9 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -20,6 +20,13 @@ namespace raft::neighbors::ivf_flat { +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided here. Please check related function +// calls after editing template definition below. Search for +// `greppable-id-specializations-ivf-flat-search` to find them. #define RAFT_INST(T, IdxT) \ extern template auto build(raft::device_resources const& handle, \ const index_params& params, \ @@ -44,7 +51,23 @@ namespace raft::neighbors::ivf_flat { const raft::neighbors::ivf_flat::index&, \ raft::device_matrix_view, \ raft::device_matrix_view, \ - raft::device_matrix_view); + raft::device_matrix_view); \ + \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); RAFT_INST(float, int64_t); RAFT_INST(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu index 6de65546c8..dce7083139 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu @@ -18,12 +18,37 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided. To make sure +// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate +// it below. Please check related function calls after editing template +// definition below. Search for `greppable-id-specializations-ivf-flat-search` +// to find them. +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu index 8eda240ccd..b03d878bae 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu index 8ff6533628..2d42bae0d1 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(uint8_t, int64_t);