Skip to content

Commit

Permalink
Prune CUB's ChainedPolicy by __CUDA_ARCH_LIST__
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Aug 14, 2024
1 parent cbce14b commit f6522ab
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 1 deletion.
72 changes: 71 additions & 1 deletion cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ struct SmVersionCacheTag
{};

/**
* \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10).
* \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10). This value
* must be one of __CUDA_ARCH_LIST__.
*
* \note This function may cache the result internally.
* \note This function is thread safe.
Expand Down Expand Up @@ -635,18 +636,74 @@ struct ChainedPolicy
template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int device_ptx_version, FunctorT& op)
{
#ifdef __CUDA_ARCH_LIST__
return runtime_to_compiletime<__CUDA_ARCH_LIST__>(device_ptx_version, op);
#else
if (device_ptx_version < PolicyPtxVersion)
{
return PrevPolicyT::Invoke(device_ptx_version, op);
}
return op.template Invoke<PolicyT>();
#endif
}

private:
template <int, typename, typename>
friend struct ChainedPolicy; // let us call invoke_static of other ChainedPolicy instantiations

template <int... CudaArches, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t runtime_to_compiletime(int device_ptx_version, FunctorT& op)
{
// we instantiate invoke_static for each CudaArches, but only call the one matching device_ptx_version
cudaError_t e = cudaSuccess;
const cudaError_t dummy[] = {
(device_ptx_version == CudaArches ? (e = invoke_static<CudaArches>(op, ::cuda::std::true_type{}))
: cudaSuccess)...};
(void) dummy;
return e;
}

template <int DevicePtxVersion, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type)
{
// TODO(bgruber): drop boolean tag dispatches in C++17, since _CCCL_IF_CONSTEXPR will discard this branch properly
_CCCL_IF_CONSTEXPR (DevicePtxVersion < PolicyPtxVersion)
{
return PrevPolicyT::template invoke_static<DevicePtxVersion>(
op, ::cuda::std::bool_constant<(DevicePtxVersion < PolicyPtxVersion)>{});
}
else
{
return DoInvoke(op, ::cuda::std::bool_constant<DevicePtxVersion >= PolicyPtxVersion>{});
}
}

template <int, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type)
{
_LIBCUDACXX_UNREACHABLE();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT& op, ::cuda::std::true_type)
{
return op.template Invoke<PolicyT>();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT&, ::cuda::std::false_type)
{
_LIBCUDACXX_UNREACHABLE();
}
};

/// Helper for dispatching into a policy chain (end-of-chain specialization)
template <int PTX_VERSION, typename PolicyT>
struct ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>
{
template <int, typename, typename>
friend struct ChainedPolicy; // befriend primary template, so it can call invoke_static

/// The policy for the active compiler pass
using ActivePolicy = PolicyT;

Expand All @@ -656,6 +713,19 @@ struct ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>
{
return op.template Invoke<PolicyT>();
}

private:
template <int, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type)
{
return op.template Invoke<PolicyT>();
}

template <int, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type)
{
_LIBCUDACXX_UNREACHABLE();
}
};

CUB_NAMESPACE_END
72 changes: 72 additions & 0 deletions cub/test/catch2_test_util_device.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,75 @@ CUB_TEST("CUB correctly identifies the ptx version the kernel was compiled for",
REQUIRE(ptx_version == kernel_cuda_arch);
REQUIRE(host_ptx_version == kernel_cuda_arch);
}

#ifdef __CUDA_ARCH_LIST__
CUB_TEST("PtxVersion returns a value from __CUDA_ARCH_LIST__", "[util][dispatch]")
{
int ptx_version = 0;
cub::PtxVersion(ptx_version);
const auto arch_list = std::vector<int>{__CUDA_ARCH_LIST__};
REQUIRE(std::find(arch_list.begin(), arch_list.end(), ptx_version) != arch_list.end());
}
#endif

#ifdef __CUDA_ARCH_LIST__
// We list policies for all virtual architectures that __CUDA_ARCH_LIST__ can contain, so the actual architectures the
// tests are compiled for should match to one of those
struct policy_hub
{
# define GEN_POLICY(cur, prev) \
struct policy##cur : cub::ChainedPolicy<cur, policy##cur, policy##prev> \
{ \
static constexpr int value = cur; \
}
// for the list of supported architectures, see libcudacxx/include/nv/target
GEN_POLICY(350, 350);
GEN_POLICY(370, 350);
GEN_POLICY(500, 370);
GEN_POLICY(520, 500);
GEN_POLICY(530, 520);
GEN_POLICY(600, 530);
GEN_POLICY(610, 600);
GEN_POLICY(620, 610);
GEN_POLICY(700, 620);
GEN_POLICY(720, 700);
GEN_POLICY(750, 720);
GEN_POLICY(800, 750);
GEN_POLICY(860, 800);
GEN_POLICY(870, 860);
GEN_POLICY(890, 870);
GEN_POLICY(900, 890);
GEN_POLICY(1000, 900);
// add more policies here when new architectures emerge
GEN_POLICY(2000, 1000); // non-existing architecture, just to test pruning
# undef GEN_POLICY

using max_policy = policy2000;
};

// Check that selected is one of arches
template <int selected, int... arch_list>
struct check
{
static_assert(::cuda::std::_Or<::cuda::std::bool_constant<selected == arch_list>...>::value, "");
using type = cudaError_t;
};

struct Closure
{
// We need to fail template instantiation if ActivePolicy::value is not one from the __CUDA_ARCH_LIST__
template <typename ActivePolicy>
_CCCL_HOST_DEVICE auto Invoke() const -> typename check<ActivePolicy::value, __CUDA_ARCH_LIST__>::type
{
return cudaSuccess;
}
};

CUB_TEST("ChainedPolicy prunes based on __CUDA_ARCH_LIST__", "[util][dispatch]")
{
int ptx_version = 0;
cub::PtxVersion(ptx_version);
Closure c;
policy_hub::max_policy::Invoke(ptx_version, c);
}
#endif

0 comments on commit f6522ab

Please sign in to comment.