Skip to content

Commit

Permalink
Address review re: distance API
Browse files Browse the repository at this point in the history
  • Loading branch information
ahendriksen committed Apr 18, 2023
1 parent 65c1cba commit 3ea52b8
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ void pairwise_matrix_dispatch(OpT distance_op,
cudaStream_t stream, \
bool is_row_major)

/*
* Hierarchy of instantiations:
*
* This file defines extern template instantiations of the distance kernels. The
* instantiation of the public API is handled in raft/distance/distance-ext.cuh.
*
* After adding an instance here, make sure to also add the instance there.
*/
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
Expand Down
12 changes: 12 additions & 0 deletions cpp/include/raft/distance/distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,18 @@ void pairwise_distance(raft::resources const& handle,

#endif // RAFT_EXPLICIT_INSTANTIATE

/*
* Hierarchy of instantiations:
*
* This file defines the extern template instantiations for the public API of
* raft::distance. To improve compile times, the extern template instantiation
* of the distance kernels is handled in
* distance/detail/pairwise_matrix/dispatch-ext.cuh.
*
* After adding an instance here, make sure to also add the instance to
* dispatch-ext.cuh and the corresponding .cu files.
*/

#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \
extern template void raft::distance::distance<DT, DataT, AccT, OutT, FinalLambda, IdxT>( \
raft::resources const& handle, \
Expand Down
91 changes: 0 additions & 91 deletions cpp/src/distance/detail/pairwise_matrix/dispatch.cu

This file was deleted.

9 changes: 9 additions & 0 deletions cpp/src/distance/distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@

#include <raft/distance/distance-inl.cuh>

/*
* Hierarchy of instantiations:
*
* This file defines the template instantiations for the public API of
* raft::distance. To improve compile times, the compilation of the distance
* kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu.
*
*/

#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \
template void raft::distance::distance<DT, DataT, AccT, OutT, FinalLambda, IdxT>( \
raft::resources const& handle, \
Expand Down

0 comments on commit 3ea52b8

Please sign in to comment.