diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh
index 7887eb96be..f0c550ed43 100644
--- a/cpp/include/raft/distance/detail/distance.cuh
+++ b/cpp/include/raft/distance/detail/distance.cuh
@@ -45,6 +45,7 @@
 #include
 #include
+#include <raft/util/arch.cuh>
 #include
 #include
 
@@ -261,8 +262,11 @@ void distance_impl(raft::resources const& handle,
     distance_matrix_dispatch(
       distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
   } else {
-    const auto deviceVersion = getComputeCapability();
-    if (deviceVersion.first >= 8) {
+    auto runtime_arch = raft::arch::kernel_runtime_arch();
+    auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
+    auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());
+
+    if (cutlass_range.contains(runtime_arch)) {
       // If device is SM_80 or later, use CUTLASS-based kernel.
       using Op = ops::cosine_cutlass_op;
       Op distance_op{};
@@ -272,8 +276,25 @@ void distance_impl(raft::resources const& handle,
     } else {
       // Else use "legacy" cosine kernel
       ops::cosine_distance_op distance_op{};
-      distance_matrix_dispatch(
-        distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
+      distance_matrix_dispatch(distance_op,
+                               m,
+                               n,
+                               k,
+                               x,
+                               y,
+                               norm_A,
+                               norm_B,
+                               out,
+                               fin_op,
+                               stream,
+                               is_row_major,
+                               legacy_range);
     }
   }
 }
@@ -527,8 +548,11 @@ void distance_impl_l2_expanded(  // NOTE: different name
     distance_matrix_dispatch(
       l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
   } else {
-    const auto deviceVersion = getComputeCapability();
-    if (deviceVersion.first >= 8) {
+    auto runtime_arch = raft::arch::kernel_runtime_arch();
+    auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
+    auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());
+
+    if (cutlass_range.contains(runtime_arch)) {
       // If device is SM_80 or later, use CUTLASS-based kernel.
       using L2Op = ops::l2_exp_cutlass_op;
       L2Op l2_op(perform_sqrt);
@@ -536,10 +560,17 @@ void distance_impl_l2_expanded(  // NOTE: different name
       distance_matrix_cutlass_dispatch(
         l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
     } else {
-      // Else use "legacy" L2
+      // Else use "legacy" L2. Compile *only* for architectures in the legacy
+      // range. For newer architectures, compile empty kernels.
       ops::l2_exp_distance_op l2_op(perform_sqrt);
-      distance_matrix_dispatch(
-        l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
+      distance_matrix_dispatch(
+        l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range);
     }
   }
 }
diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
index 9def354600..32aec6377c 100644
--- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include <raft/util/arch.cuh>
 #include
 
 namespace raft::distance::detail {
@@ -90,19 +91,22 @@ template
-          typename IdxT = int>
-void distance_matrix_dispatch(OpT distance_op,
-                              IdxT m,
-                              IdxT n,
-                              IdxT k,
-                              const DataT* x,
-                              const DataT* y,
-                              const DataT* x_norm,
-                              const DataT* y_norm,
-                              OutT* out,
-                              FinOpT fin_op,
-                              cudaStream_t stream,
-                              bool is_row_major)
+          typename IdxT = int,
+          typename SM_compat_t = raft::arch::SM_range<raft::arch::SM_min, raft::arch::SM_future>>
+void distance_matrix_dispatch(
+  OpT distance_op,
+  IdxT m,
+  IdxT n,
+  IdxT k,
+  const DataT* x,
+  const DataT* y,
+  const DataT* x_norm,
+  const DataT* y_norm,
+  OutT* out,
+  FinOpT fin_op,
+  cudaStream_t stream,
+  bool is_row_major,
+  SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()))
 {
   // Determine leading dimensions and, if column-major, flip order of passing x
   // and y.
@@ -154,7 +158,21 @@ void distance_matrix_dispatch(OpT distance_op,
     typedef typename std::conditional::type Policy;
 
     return pairwise_matrix(
-      distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream);
+      distance_op,
+      fin_op,
+      x,
+      y,
+      x_norm,
+      y_norm,
+      m,
+      n,
+      k,
+      ldx,
+      ldy,
+      ld_out,
+      out,
+      stream,
+      sm_compat_range);
   });
 }
diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh
index db7ceb64f4..3f6474deeb 100644
--- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <raft/util/arch.cuh>
 
 namespace raft::distance::detail {
 
@@ -28,21 +29,30 @@ template
-          typename FinOpT>
-__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x,
-                                                                               const DataT* y,
-                                                                               const DataT* _xn,
-                                                                               const DataT* _yn,
-                                                                               IdxT m,
-                                                                               IdxT n,
-                                                                               IdxT k,
-                                                                               IdxT lda,
-                                                                               IdxT ldb,
-                                                                               IdxT ldd,
-                                                                               OutT* dOutput,
-                                                                               opT distance_op,
-                                                                               FinOpT fin_op)
+          typename FinOpT,
+          typename SM_compat_t>
+__global__ __launch_bounds__(Policy::Nthreads,
+                             2) void pairwise_matrix_kernel(const DataT* x,
+                                                            const DataT* y,
+                                                            const DataT* _xn,
+                                                            const DataT* _yn,
+                                                            IdxT m,
+                                                            IdxT n,
+                                                            IdxT k,
+                                                            IdxT lda,
+                                                            IdxT ldb,
+                                                            IdxT ldd,
+                                                            OutT* dOutput,
+                                                            opT distance_op,
+                                                            FinOpT fin_op,
+                                                            SM_compat_t sm_compat_range)
 {
+  // Early exit to minimize the size of the kernel when it is not supposed to be compiled.
+  if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) {
+    assert(false);
+    return;
+  }
+
   extern __shared__ char smem[];
 
   // Wrap operator back into lambdas. This is temporary and should be removed.
   // (TODO)
@@ -103,7 +113,8 @@ template
-          typename FinOpT>
+          typename FinOpT,
+          typename SM_compat_t>
 void pairwise_matrix(OpT distance_op,
                      FinOpT fin_op,
                      const DataT* x,
@@ -117,18 +128,27 @@ void pairwise_matrix(OpT distance_op,
                      IdxT ldb,
                      IdxT ldd,
                      OutT* dOutput,
-                     cudaStream_t stream)
+                     cudaStream_t stream,
+                     SM_compat_t sm_compat_range)
 {
   dim3 blk(Policy::Nthreads);
   // Use .template to disambiguate (See:
   // https://en.cppreference.com/w/cpp/language/dependent_name)
   size_t smem_size = distance_op.template shared_mem_size<Policy>();
   // Obtain function pointer to kernel
-  auto kernel = pairwise_matrix_kernel;
+  auto kernel = pairwise_matrix_kernel;
   dim3 grid = launchConfigGenerator<Policy>(m, n, smem_size, kernel);
 
   kernel<<<grid, blk, smem_size, stream>>>(
-    x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op);
+    x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op, sm_compat_range);
 
   RAFT_CUDA_TRY(cudaGetLastError());
 }
diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh
new file mode 100644
index 0000000000..dfa29334f5
--- /dev/null
+++ b/cpp/include/raft/util/arch.cuh
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+namespace raft::arch {
+
+/* raft::arch provides the following facilities:
+ *
+ * - raft::arch::SM_XX : hardcoded compile-time constants for various compute
+ *   architectures. The values raft::arch::SM_min and raft::arch::SM_future
+ *   represent architectures that are always smaller and larger (respectively)
+ *   than any architecture that can be encountered in practice.
+ *
+ * - raft::arch::SM_compute_arch : a compile-time value for the *current*
+ *   compute architecture that a kernel is compiled with. It can only be used
+ *   inside kernels with a template argument.
+ *
+ * - raft::arch::kernel_runtime_arch : a function that computes at *run-time*
+ *   which version of a kernel will launch (i.e., it will return the compute
+ *   architecture of the version of the kernel that will be launched by the
+ *   driver).
+ *
+ * - raft::arch::SM_range : a compile-time value to represent an open interval
+ *   of compute architectures. This can be used to check if the current
+ *   compile-time architecture is in a specified compatibility range.
+ */
+
+// detail::SM_generic is a template to create a generic compile-time SM
+// architecture constant.
+namespace detail {
+template <int n>
+struct SM_generic {
+ public:
+  __host__ __device__ constexpr int value() const { return n; }
+};
+
+// A dummy kernel that is used to determine the runtime architecture.
+__global__ inline void dummy_runtime_kernel() {}
+}  // namespace detail
+
+// A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90)
+// and SM_MIN and SM_FUTURE, that allow specifying an open interval of
+// compatible compute architectures.
+using SM_min = detail::SM_generic<350>;
+using SM_60 = detail::SM_generic<600>;
+using SM_70 = detail::SM_generic<700>;
+using SM_75 = detail::SM_generic<750>;
+using SM_80 = detail::SM_generic<800>;
+using SM_86 = detail::SM_generic<860>;
+using SM_90 = detail::SM_generic<900>;
+using SM_future = detail::SM_generic<99999>;
+
+// This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time
+// compute architecture. It can only be used where __CUDA_ARCH__ is defined,
+// i.e., inside a __global__ function template.
+struct SM_compute_arch {
+  template <int dummy = 0>
+  __device__ constexpr int value() const
+  {
+#ifdef __CUDA_ARCH__
+    return __CUDA_ARCH__;
+#else
+    // This function should not be called in host code (because __CUDA_ARCH__ is
+    // not defined). This function is constexpr and thus can be called in host
+    // code (due to the --expt-relaxed-constexpr compile flag). We would like to
+    // provide an intelligible error message when this function is called in
+    // host code, which we do below.
+    //
+    // To make sure the static_assert only fires in host code, we use a dummy
+    // template parameter as described in P2593:
+    // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html
+    static_assert(dummy != 0,
+                  "SM_compute_arch.value() is only callable from a __global__ function template. "
+                  "A way to create a function template is by adding 'template '.");
+    return -1;
+#endif
+  }
+};
+
+// A runtime value for the actual compute architecture of a kernel.
+//
+// A single kernel can be compiled for several "virtual" compute architectures.
+// When a program runs, the driver picks the version of the kernel that most
+// closely matches the current hardware. This struct reflects the virtual
+// compute architecture of the version of the kernel that the driver picks when
+// the kernel runs.
+struct SM_runtime {
+  friend SM_runtime kernel_runtime_arch();
+
+ private:
+  const int _version;
+  SM_runtime(int version) : _version(version) {}
+
+ public:
+  __host__ __device__ int value() const { return _version; }
+};
+
+// Computes which compute architecture of a kernel will run
+//
+// Semantics are described above in the documentation of SM_runtime.
+inline SM_runtime kernel_runtime_arch()
+{
+  auto kernel = detail::dummy_runtime_kernel;
+  cudaFuncAttributes attributes;
+  cudaFuncGetAttributes(&attributes, kernel);
+
+  return SM_runtime(10 * attributes.ptxVersion);
+}
+
+// SM_range represents a range of SM architectures. It can be used to
+// conditionally compile a kernel.
+template <typename SM_MIN, typename SM_MAX>
+struct SM_range {
+ private:
+  const SM_MIN _min;
+  const SM_MAX _max;
+
+ public:
+  __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) : _min(min), _max(max) {}
+
+  template <typename SM_t>
+  __host__ __device__ constexpr bool contains(SM_t current) const
+  {
+    return _min.value() <= current.value() && current.value() < _max.value();
+  }
+};
+
+}  // namespace raft::arch
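
Usage sketch (not part of the patch): the host side asks which architecture the driver will actually execute via raft::arch::kernel_runtime_arch(), while the kernel gates its body at compile time by testing raft::arch::SM_compute_arch() against an SM_range, mirroring how pairwise_matrix_kernel uses it above. The kernel and launcher names below (gated_kernel, launch_if_legacy) are hypothetical and for illustration only; only kernel_runtime_arch(), SM_range, SM_min, SM_80 and SM_compute_arch come from this change.

#include <raft/util/arch.cuh>

#include <cuda_runtime.h>

// Hypothetical kernel template: its body is compiled only for architectures
// inside `compat_range`; for all other architectures it compiles to a
// (nearly) empty kernel, keeping the fat binary small.
template <typename SM_compat_t>
__global__ void gated_kernel(int* out, SM_compat_t compat_range)
{
  if constexpr (compat_range.contains(raft::arch::SM_compute_arch())) {
    out[threadIdx.x] = threadIdx.x;
  }
}

// Hypothetical host-side dispatch: query the architecture the driver will run
// and launch the gated kernel only when that architecture is in range.
inline void launch_if_legacy(int* d_out, cudaStream_t stream)
{
  auto runtime_arch = raft::arch::kernel_runtime_arch();
  auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());

  if (legacy_range.contains(runtime_arch)) {
    gated_kernel<<<1, 32, 0, stream>>>(d_out, legacy_range);
  }
}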