From f5833de9eed8adb7f0597549db66efaaeac6c6e6 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 2 Aug 2023 04:31:22 +0200 Subject: [PATCH] Enable CUTLASS-based distance kernels on CTK 12 (#1702) The CUTLASS-based kernels were disabled on CTK 12. This PR re-enables them. Authors: - Allard Hendriksen (https://github.com/ahendriksen) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1702 --- .../distance/detail/pairwise_matrix/dispatch-inl.cuh | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh index b768008c7f..fd9d444662 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -89,19 +89,15 @@ void pairwise_matrix_dispatch(OpT distance_op, if (!params.is_row_major) { params.flip_x_and_y(); } - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: + // Dispatch rule: // - execute CUTLASS-based kernel on SM_80 and above // - execute normal kernel below SM_80 namespace arch = raft::util::arch; - constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - if constexpr (is_ctk_12 || cutlass_op_unavailable) { - // Always execute legacy kernels on CUDA 12 + if constexpr (cutlass_op_unavailable) { + // Always execute legacy kernels when no cutlass op is available auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else {