From 749d000dbfd05a6c429eec3ab8475c21b6e38e65 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 16:03:46 +0100 Subject: [PATCH 1/4] Add dispatch based on compute architecture --- cpp/include/raft/distance/detail/distance.cuh | 26 ++-- .../detail/pairwise_matrix/dispatch.cuh | 9 +- .../detail/pairwise_matrix/kernel_sm60.cuh | 23 ++- cpp/include/raft/util/arch.cuh | 133 ++++++++++++++++++ 4 files changed, 173 insertions(+), 18 deletions(-) create mode 100644 cpp/include/raft/util/arch.cuh diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 621e2d15b9..da119b6a45 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -46,6 +46,7 @@ #include #include #include +#include #include namespace raft { @@ -261,8 +262,11 @@ void distance_impl(raft::resources const& handle, distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { + auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + + if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. using Op = ops::cosine_cutlass_op; Op distance_op{}; @@ -272,8 +276,8 @@ void distance_impl(raft::resources const& handle, } else { // Else use "legacy" L2 ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); } } } @@ -527,8 +531,11 @@ void distance_impl_l2_expanded( // NOTE: different name distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { + auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + + if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. using L2Op = ops::l2_exp_cutlass_op; L2Op l2_op(perform_sqrt); @@ -536,10 +543,11 @@ void distance_impl_l2_expanded( // NOTE: different name distance_matrix_cutlass_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - // Else use "legacy" L2 + // Else use "legacy" L2. Compile *only* for architectures in the legacy + // range. For newer architectures, compile empty kernels. ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_matrix_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 23d0f34489..75680027d8 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -17,6 +17,7 @@ #include "kernel_sm60.cuh" #include +#include #include #include #include @@ -89,7 +90,8 @@ template + typename IdxT = int, + typename SM_compat_t = raft::arch::SM_range> void distance_matrix_dispatch(OpT distance_op, IdxT m, IdxT n, @@ -101,7 +103,8 @@ void distance_matrix_dispatch(OpT distance_op, OutT* out, FinOpT fin_op, cudaStream_t stream, - bool is_row_major) + bool is_row_major, + SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) { // Determine leading dimensions and, if column-major, flip order of passing x // and y. @@ -145,7 +148,7 @@ void distance_matrix_dispatch(OpT distance_op, typedef typename std::conditional::type Policy; return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); + distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream, sm_compat_range); }); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index db7ceb64f4..7404c47e66 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include @@ -28,7 +29,8 @@ template + typename FinOpT, + typename SM_compat_t> __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x, const DataT* y, const DataT* _xn, @@ -41,8 +43,15 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(co IdxT ldd, OutT* dOutput, opT distance_op, - FinOpT fin_op) + FinOpT fin_op, + SM_compat_t sm_compat_range) { + // Early exit to minimize the size of the kernel when it is not supposed to be compiled. + if constexpr(! sm_compat_range.contains(raft::arch::SM_compute_arch())) { + assert(false); + return; + } + extern __shared__ char smem[]; // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) @@ -103,7 +112,8 @@ template + typename FinOpT, + typename SM_compat_t> void pairwise_matrix(OpT distance_op, FinOpT fin_op, const DataT* x, @@ -117,18 +127,19 @@ void pairwise_matrix(OpT distance_op, IdxT ldb, IdxT ldd, OutT* dOutput, - cudaStream_t stream) + cudaStream_t stream, + SM_compat_t sm_compat_range) { dim3 blk(Policy::Nthreads); // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel - auto kernel = pairwise_matrix_kernel; + auto kernel = pairwise_matrix_kernel; dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); kernel<<>>( - x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); + x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op, sm_compat_range); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh new file mode 100644 index 0000000000..554805da9e --- /dev/null +++ b/cpp/include/raft/util/arch.cuh @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace raft::arch { + +/* raft::arch provides the following facilities: + * + * - raft::arch::SM_XX : hardcoded compile-time constants for various compute + * architectures. The values raft::arch::SM_min and raft::arch::SM_future + * represent architectures that are always smaller and larger (respectively) + * than any architecture that can be encountered in practice. + * + * - raft::arch::SM_compute_arch : a compile-time value for the *current* + * compute architecture that a kernel is compiled with. It can only be used + * inside kernels with a template argument. + * + * - raft::arch::kernel_runtime_arch : a function that computes at *run-time* + * which version of a kernel will launch (i.e., it will return the compute + * architecture of the version of the kernel that will be launched by the + * driver). + * + * - raft::arch::SM_range : a compile-time value to represent an open interval + * of compute architectures. This can be used to check if the current + * compile-time architecture is in a specified compatibility range. + */ + +// inner::SM_generic is a template to create a generic compile-time SM +// architecture constant. +namespace inner { +template +struct SM_generic { +public: + __host__ __device__ constexpr int value() const { + return n; + } +}; + +// A +__global__ inline void dummy_runtime_kernel() {} +} + +// A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) +// and SM_MIN and SM_FUTURE, that allow specifying an open interval of +// compatible compute architectures. +using SM_min = inner::SM_generic<350>; +using SM_60 = inner::SM_generic<600>; +using SM_70 = inner::SM_generic<700>; +using SM_75 = inner::SM_generic<750>; +using SM_80 = inner::SM_generic<800>; +using SM_86 = inner::SM_generic<860>; +using SM_90 = inner::SM_generic<900>; +using SM_future = inner::SM_generic<99999>; + +// This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time +// compute architecture. It can only be used where __CUDA_ARCH__ is defined, +// i.e., inside a __global__ function template. +struct SM_compute_arch { + template + __host__ __device__ constexpr int value() const { +#ifdef __CUDA_ARCH__ + return __CUDA_ARCH__; +#else + static_assert(dummy != 0, + "SM_compute_arch.value() is only callable from a __global__ function template. " + "A way to create a function template is by adding 'template '."); + return -1; +#endif + } +}; + +// A runtime value for the actual compute architecture of a kernel. +// +// A single kernel can be compiled for several "virtual" compute architectures. +// When a program runs, the driver picks the version of the kernel that most +// closely matches the current hardware. This struct reflects the virtual +// compute architecture of the version of the kernel that the driver picks when +// the kernel runs. +struct SM_runtime { + friend SM_runtime kernel_runtime_arch(); +private: + const int _version; + SM_runtime(int version) + : _version (version) {} + +public: + __host__ __device__ int value() const { + return _version; + } +}; + +// Computes which compute architecture of a kernel will run +// +// Semantics are described above in the documentation of SM_runtime. +SM_runtime kernel_runtime_arch() { + auto kernel = inner::dummy_runtime_kernel; + cudaFuncAttributes attributes; + cudaFuncGetAttributes(&attributes, kernel); + + return SM_runtime(10 * attributes.ptxVersion); +} + +// SM_range represents a range of SM architectures. It can be used to +// conditionally compile a kernel. +template +struct SM_range { +private: + const SM_MIN _min; + const SM_MAX _max; +public: + __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) + : _min(min), _max(max) {} + + template + __host__ __device__ constexpr bool contains(SM_t current) const { + return _min.value() <= current.value() && current.value() < _max.value(); + } +}; + +} // namespace raft::arch From 72628613991f55c08e8e70e96c82ed050d84936b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 16:39:09 +0100 Subject: [PATCH 2/4] Fix style --- cpp/include/raft/distance/detail/distance.cuh | 39 +++++++++++--- .../detail/pairwise_matrix/dispatch.cuh | 47 ++++++++++------ .../detail/pairwise_matrix/kernel_sm60.cuh | 43 +++++++++------ cpp/include/raft/util/arch.cuh | 53 +++++++++---------- 4 files changed, 114 insertions(+), 68 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index da119b6a45..7ebc7b3414 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -45,8 +45,8 @@ #include #include -#include #include +#include #include namespace raft { @@ -262,9 +262,9 @@ void distance_impl(raft::resources const& handle, distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto runtime_arch = raft::arch::kernel_runtime_arch(); auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -276,8 +276,25 @@ void distance_impl(raft::resources const& handle, } else { // Else use "legacy" L2 ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); + distance_matrix_dispatch(distance_op, + m, + n, + k, + x, + y, + norm_A, + norm_B, + out, + fin_op, + stream, + is_row_major, + legacy_range); } } } @@ -531,9 +548,9 @@ void distance_impl_l2_expanded( // NOTE: different name distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto runtime_arch = raft::arch::kernel_runtime_arch(); auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -546,7 +563,13 @@ void distance_impl_l2_expanded( // NOTE: different name // Else use "legacy" L2. Compile *only* for architectures in the legacy // range. For newer architectures, compile empty kernels. ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( + distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 75680027d8..9bbeca1e90 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -17,9 +17,9 @@ #include "kernel_sm60.cuh" #include -#include #include #include +#include #include namespace raft::distance::detail { @@ -90,21 +90,22 @@ template > -void distance_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) +void distance_matrix_dispatch( + OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) { // Determine leading dimensions and, if column-major, flip order of passing x // and y. @@ -148,7 +149,21 @@ void distance_matrix_dispatch(OpT distance_op, typedef typename std::conditional::type Policy; return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream, sm_compat_range); + distance_op, + fin_op, + x, + y, + x_norm, + y_norm, + m, + n, + k, + ldx, + ldy, + ld_out, + out, + stream, + sm_compat_range); }); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 7404c47e66..3f6474deeb 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -16,9 +16,9 @@ #pragma once #include -#include #include #include +#include namespace raft::distance::detail { @@ -31,23 +31,24 @@ template -__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - opT distance_op, - FinOpT fin_op, - SM_compat_t sm_compat_range) +__global__ __launch_bounds__(Policy::Nthreads, + 2) void pairwise_matrix_kernel(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + opT distance_op, + FinOpT fin_op, + SM_compat_t sm_compat_range) { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. - if constexpr(! sm_compat_range.contains(raft::arch::SM_compute_arch())) { + if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { assert(false); return; } @@ -135,7 +136,15 @@ void pairwise_matrix(OpT distance_op, // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel - auto kernel = pairwise_matrix_kernel; + auto kernel = pairwise_matrix_kernel; dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); kernel<<>>( diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 554805da9e..5103c2c591 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -43,26 +43,24 @@ namespace raft::arch { namespace inner { template struct SM_generic { -public: - __host__ __device__ constexpr int value() const { - return n; - } + public: + __host__ __device__ constexpr int value() const { return n; } }; // A __global__ inline void dummy_runtime_kernel() {} -} +} // namespace inner // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) // and SM_MIN and SM_FUTURE, that allow specifying an open interval of // compatible compute architectures. -using SM_min = inner::SM_generic<350>; -using SM_60 = inner::SM_generic<600>; -using SM_70 = inner::SM_generic<700>; -using SM_75 = inner::SM_generic<750>; -using SM_80 = inner::SM_generic<800>; -using SM_86 = inner::SM_generic<860>; -using SM_90 = inner::SM_generic<900>; +using SM_min = inner::SM_generic<350>; +using SM_60 = inner::SM_generic<600>; +using SM_70 = inner::SM_generic<700>; +using SM_75 = inner::SM_generic<750>; +using SM_80 = inner::SM_generic<800>; +using SM_86 = inner::SM_generic<860>; +using SM_90 = inner::SM_generic<900>; using SM_future = inner::SM_generic<99999>; // This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time @@ -70,7 +68,8 @@ using SM_future = inner::SM_generic<99999>; // i.e., inside a __global__ function template. struct SM_compute_arch { template - __host__ __device__ constexpr int value() const { + __host__ __device__ constexpr int value() const + { #ifdef __CUDA_ARCH__ return __CUDA_ARCH__; #else @@ -91,21 +90,20 @@ struct SM_compute_arch { // the kernel runs. struct SM_runtime { friend SM_runtime kernel_runtime_arch(); -private: + + private: const int _version; - SM_runtime(int version) - : _version (version) {} + SM_runtime(int version) : _version(version) {} -public: - __host__ __device__ int value() const { - return _version; - } + public: + __host__ __device__ int value() const { return _version; } }; // Computes which compute architecture of a kernel will run // // Semantics are described above in the documentation of SM_runtime. -SM_runtime kernel_runtime_arch() { +SM_runtime kernel_runtime_arch() +{ auto kernel = inner::dummy_runtime_kernel; cudaFuncAttributes attributes; cudaFuncGetAttributes(&attributes, kernel); @@ -117,17 +115,18 @@ SM_runtime kernel_runtime_arch() { // conditionally compile a kernel. template struct SM_range { -private: + private: const SM_MIN _min; const SM_MAX _max; -public: - __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) - : _min(min), _max(max) {} + + public: + __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) : _min(min), _max(max) {} template - __host__ __device__ constexpr bool contains(SM_t current) const { + __host__ __device__ constexpr bool contains(SM_t current) const + { return _min.value() <= current.value() && current.value() < _max.value(); } }; -} // namespace raft::arch +} // namespace raft::arch From 09a30501c00b36af296fe6513135c5b8b95a69a6 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 6 Mar 2023 16:55:28 +0100 Subject: [PATCH 3/4] Fix linker error: multiple definition.. --- cpp/include/raft/util/arch.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 5103c2c591..ef703a8486 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -47,7 +47,7 @@ struct SM_generic { __host__ __device__ constexpr int value() const { return n; } }; -// A +// A dummy kernel that is used to determine the runtime architecture. __global__ inline void dummy_runtime_kernel() {} } // namespace inner @@ -102,7 +102,7 @@ struct SM_runtime { // Computes which compute architecture of a kernel will run // // Semantics are described above in the documentation of SM_runtime. -SM_runtime kernel_runtime_arch() +inline SM_runtime kernel_runtime_arch() { auto kernel = inner::dummy_runtime_kernel; cudaFuncAttributes attributes; From 1a6636f994d3e6e48cd955ecd9ab59c159a54a1c Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 7 Mar 2023 14:20:43 +0100 Subject: [PATCH 4/4] Implement review feedback --- cpp/include/raft/util/arch.cuh | 35 +++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index ef703a8486..dfa29334f5 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -38,9 +38,9 @@ namespace raft::arch { * compile-time architecture is in a specified compatibility range. */ -// inner::SM_generic is a template to create a generic compile-time SM +// detail::SM_generic is a template to create a generic compile-time SM // architecture constant. -namespace inner { +namespace detail { template struct SM_generic { public: @@ -49,30 +49,39 @@ struct SM_generic { // A dummy kernel that is used to determine the runtime architecture. __global__ inline void dummy_runtime_kernel() {} -} // namespace inner +} // namespace detail // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) // and SM_MIN and SM_FUTURE, that allow specifying an open interval of // compatible compute architectures. -using SM_min = inner::SM_generic<350>; -using SM_60 = inner::SM_generic<600>; -using SM_70 = inner::SM_generic<700>; -using SM_75 = inner::SM_generic<750>; -using SM_80 = inner::SM_generic<800>; -using SM_86 = inner::SM_generic<860>; -using SM_90 = inner::SM_generic<900>; -using SM_future = inner::SM_generic<99999>; +using SM_min = detail::SM_generic<350>; +using SM_60 = detail::SM_generic<600>; +using SM_70 = detail::SM_generic<700>; +using SM_75 = detail::SM_generic<750>; +using SM_80 = detail::SM_generic<800>; +using SM_86 = detail::SM_generic<860>; +using SM_90 = detail::SM_generic<900>; +using SM_future = detail::SM_generic<99999>; // This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time // compute architecture. It can only be used where __CUDA_ARCH__ is defined, // i.e., inside a __global__ function template. struct SM_compute_arch { template - __host__ __device__ constexpr int value() const + __device__ constexpr int value() const { #ifdef __CUDA_ARCH__ return __CUDA_ARCH__; #else + // This function should not be called in host code (because __CUDA_ARCH__ is + // not defined). This function is constexpr and thus can be called in host + // code (due to the --expt-relaxed-constexpr compile flag). We would like to + // provide an intelligible error message when this function is called in + // host code, which we do below. + // + // To make sure the static_assert only fires in host code, we use a dummy + // template parameter as described in P2593: + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html static_assert(dummy != 0, "SM_compute_arch.value() is only callable from a __global__ function template. " "A way to create a function template is by adding 'template '."); @@ -104,7 +113,7 @@ struct SM_runtime { // Semantics are described above in the documentation of SM_runtime. inline SM_runtime kernel_runtime_arch() { - auto kernel = inner::dummy_runtime_kernel; + auto kernel = detail::dummy_runtime_kernel; cudaFuncAttributes attributes; cudaFuncGetAttributes(&attributes, kernel);