Enable CUTLASS-based distance kernels on CTK 12 (#1702)
The CUTLASS-based kernels were disabled on CTK 12. This PR re-enables them.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1702
ahendriksen authored Aug 2, 2023
1 parent 1075287 commit f5833de
Showing 1 changed file with 3 additions and 7 deletions.
@@ -89,19 +89,15 @@ void pairwise_matrix_dispatch(OpT distance_op,

   if (!params.is_row_major) { params.flip_x_and_y(); }

-  // On CUDA 12:
-  // - always execute normal kernel
-  //
-  // On CUDA 11 and below:
+  // Dispatch rule:
   // - execute CUTLASS-based kernel on SM_80 and above
   // - execute normal kernel below SM_80
   namespace arch = raft::util::arch;

-  constexpr bool is_ctk_12              = __CUDACC_VER_MAJOR__ == 12;
   constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op<OpT>();

-  if constexpr (is_ctk_12 || cutlass_op_unavailable) {
-    // Always execute legacy kernels on CUDA 12
+  if constexpr (cutlass_op_unavailable) {
+    // Always execute legacy kernels when no cutlass op is available
     auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future());
     pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream);
   } else {
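The dispatch rule after this change can be illustrated with a small, host-only sketch. It is not RAFT code: `l2_op`, `custom_op`, `kernel_path`, `select_kernel`, and the local `has_cutlass_op` trait are hypothetical stand-ins, and the SM version is assumed to be a plain integer (80 for SM_80). The sketch only shows the shape of the decision: a compile-time check on whether a CUTLASS op exists, followed by a runtime check on the architecture.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for distance ops; these are not RAFT's real types.
struct l2_op {};      // assume a CUTLASS variant exists for this op
struct custom_op {};  // assume no CUTLASS variant exists for this op

// Hypothetical trait playing the role of ops::has_cutlass_op<OpT>().
template <typename OpT>
constexpr bool has_cutlass_op()
{
  return std::is_same_v<OpT, l2_op>;
}

enum class kernel_path { legacy, cutlass };

// Returns which kernel path the rule would pick.
// sm_version is an assumed integer encoding, e.g. 80 for SM_80.
template <typename OpT>
constexpr kernel_path select_kernel(OpT, int sm_version)
{
  constexpr bool cutlass_op_unavailable = !has_cutlass_op<OpT>();

  if constexpr (cutlass_op_unavailable) {
    // No CUTLASS op: legacy kernel on every architecture.
    return kernel_path::legacy;
  } else {
    // CUTLASS op available: use it on SM_80 and above, legacy below.
    return sm_version >= 80 ? kernel_path::cutlass : kernel_path::legacy;
  }
}

// Small runtime check of the three cases in the dispatch rule.
inline void run_examples()
{
  assert(select_kernel(l2_op{}, 80) == kernel_path::cutlass);
  assert(select_kernel(l2_op{}, 70) == kernel_path::legacy);
  assert(select_kernel(custom_op{}, 90) == kernel_path::legacy);
}
```

Note that the toolkit version no longer appears anywhere in the decision, which mirrors the removal of `is_ctk_12` in the diff above.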
