diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index 714aa014ceb..369ed291933 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -358,7 +358,8 @@ struct SmVersionCacheTag {}; /** - * \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10). + * \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10). If + * __CUDA_ARCH_LIST__ is defined, this value is one of __CUDA_ARCH_LIST__. * * \note This function may cache the result internally. * \note This function is thread safe. @@ -635,11 +636,69 @@ struct ChainedPolicy template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int device_ptx_version, FunctorT& op) { + // __CUDA_ARCH_LIST__ is only available from CTK 11.5 onwards +#ifdef __CUDA_ARCH_LIST__ + return runtime_to_compiletime<__CUDA_ARCH_LIST__>(device_ptx_version, op); +#else if (device_ptx_version < PolicyPtxVersion) { return PrevPolicyT::Invoke(device_ptx_version, op); } return op.template Invoke(); +#endif + } + +private: + template + friend struct ChainedPolicy; // let us call invoke_static of other ChainedPolicy instantiations + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t runtime_to_compiletime(int device_ptx_version, FunctorT& op) + { + // we instantiate invoke_static for each CudaArches, but only call the one matching device_ptx_version + cudaError_t e = cudaSuccess; + const cudaError_t dummy[] = { + (device_ptx_version == CudaArches ? (e = invoke_static(op, ::cuda::std::true_type{})) + : cudaSuccess)...}; + (void) dummy; + return e; + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type) + { + // TODO(bgruber): drop diagnostic suppression in C++17 + _CCCL_DIAG_PUSH + _CCCL_DIAG_SUPPRESS_MSVC(4127) // suppress Conditional Expression is Constant + _CCCL_IF_CONSTEXPR (DevicePtxVersion < PolicyPtxVersion) + { + // TODO(bgruber): drop boolean tag dispatches in C++17, since _CCCL_IF_CONSTEXPR will discard this branch properly + return PrevPolicyT::template invoke_static( + op, ::cuda::std::bool_constant<(DevicePtxVersion < PolicyPtxVersion)>{}); + } + else + { + return do_invoke(op, ::cuda::std::bool_constant= PolicyPtxVersion>{}); + } + _CCCL_DIAG_POP + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type) + { + _LIBCUDACXX_UNREACHABLE(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t do_invoke(FunctorT& op, ::cuda::std::true_type) + { + return op.template Invoke(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t do_invoke(FunctorT&, ::cuda::std::false_type) + { + _LIBCUDACXX_UNREACHABLE(); } }; @@ -647,6 +706,9 @@ struct ChainedPolicy template struct ChainedPolicy { + template + friend struct ChainedPolicy; // befriend primary template, so it can call invoke_static + /// The policy for the active compiler pass using ActivePolicy = PolicyT; @@ -656,6 +718,19 @@ struct ChainedPolicy { return op.template Invoke(); } + +private: + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type) + { + return op.template Invoke(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type) + { + _LIBCUDACXX_UNREACHABLE(); + } }; CUB_NAMESPACE_END diff --git a/cub/test/catch2_test_util_device.cu b/cub/test/catch2_test_util_device.cu index c59c076ec50..edd9348ad96 100644 --- a/cub/test/catch2_test_util_device.cu +++ b/cub/test/catch2_test_util_device.cu @@ -32,6 +32,9 @@ #include #include +#include +#include + #include "catch2_test_helper.h" #include "catch2_test_launch_helper.h" @@ -87,3 +90,163 @@ CUB_TEST("CUB correctly identifies the ptx version the kernel was compiled for", REQUIRE(ptx_version == kernel_cuda_arch); REQUIRE(host_ptx_version == kernel_cuda_arch); } + +#ifdef __CUDA_ARCH_LIST__ +CUB_TEST("PtxVersion returns a value from __CUDA_ARCH_LIST__", "[util][dispatch]") +{ + int ptx_version = 0; + cub::PtxVersion(ptx_version); + const auto arch_list = std::vector{__CUDA_ARCH_LIST__}; + REQUIRE(std::find(arch_list.begin(), arch_list.end(), ptx_version) != arch_list.end()); +} +#endif + +#ifdef __CUDA_ARCH_LIST__ +// We list policies for all virtual architectures that __CUDA_ARCH_LIST__ can contain, so the actual architectures the +// tests are compiled for should match to one of those +struct policy_hub_all +{ +# define GEN_POLICY(cur, prev) \ + struct policy##cur : cub::ChainedPolicy \ + { \ + static constexpr int value = cur; \ + } + // for the list of supported architectures, see libcudacxx/include/nv/target + GEN_POLICY(350, 350); + GEN_POLICY(370, 350); + GEN_POLICY(500, 370); + GEN_POLICY(520, 500); + GEN_POLICY(530, 520); + GEN_POLICY(600, 530); + GEN_POLICY(610, 600); + GEN_POLICY(620, 610); + GEN_POLICY(700, 620); + GEN_POLICY(720, 700); + GEN_POLICY(750, 720); + GEN_POLICY(800, 750); + GEN_POLICY(860, 800); + GEN_POLICY(870, 860); + GEN_POLICY(890, 870); + GEN_POLICY(900, 890); + GEN_POLICY(1000, 900); + // add more policies here when new architectures emerge + GEN_POLICY(2000, 1000); // non-existing architecture, just to test pruning + + using max_policy = policy2000; +}; + +// Check that selected is one of arches +template +struct check +{ + static_assert(::cuda::std::_Or<::cuda::std::bool_constant...>::value, ""); + using type = cudaError_t; +}; + +struct closure_all +{ + int ptx_version; + + // We need to fail template instantiation if ActivePolicy::value is not one from the __CUDA_ARCH_LIST__ + template + CUB_RUNTIME_FUNCTION auto Invoke() const -> typename check::type + { + // policy_hub_all must list all PTX virtual architectures, so we can do an exact comparison here +# if TEST_LAUNCH == 0 + REQUIRE(+ActivePolicy::value == ptx_version); +# endif // TEST_LAUNCH == 0 + return +ActivePolicy::value == ptx_version ? cudaSuccess : cudaErrorInvalidValue; + } +}; + +CUB_RUNTIME_FUNCTION cudaError_t +check_chained_policy_prunes_to_arch_list(void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t = 0) +{ + if (d_temp_storage == nullptr) + { + temp_storage_bytes = 1; + return cudaSuccess; + } + int ptx_version = 0; + cub::PtxVersion(ptx_version); + closure_all c{ptx_version}; + return policy_hub_all::max_policy::Invoke(ptx_version, c); +} + +DECLARE_LAUNCH_WRAPPER(check_chained_policy_prunes_to_arch_list, check_wrapper_all); + +CUB_TEST("ChainedPolicy prunes based on __CUDA_ARCH_LIST__", "[util][dispatch]") +{ + check_wrapper_all(); +} +#endif + +template +struct check_policy_closure +{ + int ptx_version; + ::cuda::std::array policies; + + template + CUB_RUNTIME_FUNCTION cudaError_t Invoke() const + { +#define CHECK_EXPR +ActivePolicy::value == ::cuda::std::lower_bound(policies.begin(), policies.end(), ptx_version)[-1] +#if TEST_LAUNCH == 0 + CAPTURE(ptx_version, policies); + REQUIRE(CHECK_EXPR); +#endif // TEST_LAUNCH == 0 + return CHECK_EXPR ? cudaSuccess : cudaErrorInvalidValue; +#undef CHECK_EXPR + } +}; + +template +CUB_RUNTIME_FUNCTION cudaError_t check_chained_policy_selects_correct_policy( + void* d_temp_storage, size_t& temp_storage_bytes, ::cuda::std::array policies, cudaStream_t = 0) +{ + if (d_temp_storage == nullptr) + { + temp_storage_bytes = 1; + return cudaSuccess; + } + int ptx_version = 0; + cub::PtxVersion(ptx_version); + check_policy_closure c{ptx_version, std::move(policies)}; + return PolicyHub::max_policy::Invoke(ptx_version, c); +} + +DECLARE_TMPL_LAUNCH_WRAPPER(check_chained_policy_selects_correct_policy, + check_wrapper_some, + ESCAPE_LIST(typename PolicyHub, int NumPolicies), + ESCAPE_LIST(PolicyHub, NumPolicies)); + +struct policy_hub_some +{ + GEN_POLICY(350, 350); + GEN_POLICY(500, 350); + GEN_POLICY(700, 500); + GEN_POLICY(900, 700); + GEN_POLICY(2000, 900); // non-existing architecture, just to test + using max_policy = policy2000; +}; + +struct policy_hub_few +{ + GEN_POLICY(350, 350); + GEN_POLICY(600, 350); + GEN_POLICY(2000, 600); // non-existing architecture, just to test + using max_policy = policy2000; +}; + +struct policy_hub_minimal +{ + GEN_POLICY(350, 350); + using max_policy = policy350; +}; + +CUB_TEST("ChainedPolicy invokes correct policy", "[util][dispatch]") +{ + check_wrapper_some(::cuda::std::array{350, 500, 700, 900, 2000}); + check_wrapper_some(::cuda::std::array{350, 600, 2000}); + check_wrapper_some(::cuda::std::array{350}); +}