From 6d929bb7a4c71debaaf01ba62bd3f27d1a1a23ac Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 1 Aug 2023 14:27:38 +0200 Subject: [PATCH] Enable CUTLASS-based distance kernels on CTK 12 --- .../distance/detail/pairwise_matrix/dispatch-inl.cuh | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh index b768008c7f..fd9d444662 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -89,19 +89,15 @@ void pairwise_matrix_dispatch(OpT distance_op, if (!params.is_row_major) { params.flip_x_and_y(); } - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: + // Dispatch rule: // - execute CUTLASS-based kernel on SM_80 and above // - execute normal kernel below SM_80 namespace arch = raft::util::arch; - constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - if constexpr (is_ctk_12 || cutlass_op_unavailable) { - // Always execute legacy kernels on CUDA 12 + if constexpr (cutlass_op_unavailable) { + // Always execute legacy kernels when no cutlass op is available auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else {