diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 2ff8fa7f1c..68922943f4 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -325,12 +325,13 @@ void fusedL2NNImpl(OutT* min, decltype(distance_op), decltype(fin_op)>; - // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not + // Get pointer to fp32 SIMT kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not // require a pointer) can be error prone. See: // https://github.com/NVIDIA/cub/issues/545 void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); if (cutlass_range.contains(runtime_arch)) { diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh index bb4422735b..b768008c7f 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -108,13 +108,14 @@ void pairwise_matrix_dispatch(OpT distance_op, auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - // Get pointer to SM60 kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not + // Get pointer to SM60 kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not // require a pointer) can be error prone. See: // https://github.com/NVIDIA/cub/issues/545 auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index dc35b10063..1a67eded44 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -30,10 +30,10 @@ namespace raft::util::arch { * compute architecture that a kernel is compiled with. It can only be used * inside kernels with a template argument. * - * - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time* - * which version of a kernel will launch (i.e., it will return the compute - * architecture of the version of the kernel that will be launched by the - * driver). + * - raft::util::arch::kernel_virtual_arch : a function that computes at *run-time* + * which version of a kernel will launch (i.e., it will return the virtual compute + * architecture of the version of the kernel that it was compiled for which + * will be launched by the driver). * * - raft::util::arch::SM_range : a compile-time value to represent an open interval * of compute architectures. This can be used to check if the current @@ -97,7 +97,7 @@ struct SM_compute_arch { // compute architecture of the version of the kernel that the driver picks when // the kernel runs. struct SM_runtime { - friend SM_runtime kernel_runtime_arch(void*); + friend SM_runtime kernel_virtual_arch(void*); private: const int _version; @@ -107,7 +107,8 @@ struct SM_runtime { __host__ __device__ int value() const { return _version; } }; -// Computes which compute architecture of a kernel will run +// Computes which virtual compute architecture the given kernel was compiled for, +// driver picks the version of the kernel that closely matches the current hardware. // // Semantics are described above in the documentation of SM_runtime. // @@ -115,7 +116,7 @@ struct SM_runtime { // to determine the architecture (that do not require a pointer) can be error // prone. See: // https://github.com/NVIDIA/cub/issues/545 -inline SM_runtime kernel_runtime_arch(void* kernel) +inline SM_runtime kernel_virtual_arch(void* kernel) { cudaFuncAttributes attributes; RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel));