Skip to content

Commit

Permalink
Refactor CUB util_device
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jul 16, 2024
1 parent 99e7d60 commit 0760911
Showing 1 changed file with 13 additions and 31 deletions.
44 changes: 13 additions & 31 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ CUB_RUNTIME_FUNCTION inline int DeviceCountUncached()

/**
* \brief Cache for an arbitrary value produced by a nullary function.
* deprecated [Since 2.6.0]
*/
template <typename T, T (*Function)()>
struct ValueCache
struct CUB_DEPRECATED ValueCache
{
T const value;

Expand All @@ -170,13 +171,11 @@ struct ValueCache
{}
};

// Host code, only safely usable in C++11 or newer, where thread-safe
// initialization of static locals is guaranteed. This is a separate function
// to avoid defining a local static in a host/device function.
// Host code. This is a separate function to avoid defining a local static in a host/device function.
_CCCL_HOST inline int DeviceCountCachedValue()
{
static ValueCache<int, DeviceCountUncached> cache;
return cache.value;
static int count = DeviceCountUncached();
return count;
}

/**
Expand Down Expand Up @@ -211,7 +210,7 @@ struct PerDeviceAttributeCache
// Each entry starts in the `DeviceEntryEmpty` state, then proceeds to the
// `DeviceEntryInitializing` state, and then proceeds to the
// `DeviceEntryReady` state. These are the only state transitions allowed;
// e.g. a linear sequence of transitions.
// i.e. a linear sequence of transitions.
enum DeviceEntryStatus
{
DeviceEntryEmpty = 0,
Expand Down Expand Up @@ -372,7 +371,6 @@ _CCCL_HOST inline cudaError_t PtxVersionUncached(int& ptx_version, int device)
template <typename Tag>
_CCCL_HOST inline PerDeviceAttributeCache& GetPerDeviceAttributeCache()
{
// C++11 guarantees that initialization of static locals is thread safe.
static PerDeviceAttributeCache cache;
return cache;
}
Expand All @@ -392,8 +390,7 @@ struct SmVersionCacheTag
_CCCL_HOST inline cudaError_t PtxVersion(int& ptx_version, int device)
{
auto const payload = GetPerDeviceAttributeCache<PtxVersionCacheTag>()(
// If this call fails, then we get the error code back in the payload,
// which we check with `CubDebug` below.
// If this call fails, then we get the error code back in the payload, which we check with `CubDebug` below.
[=](int& pv) {
return PtxVersionUncached(pv, device);
},
Expand All @@ -417,23 +414,10 @@ _CCCL_HOST inline cudaError_t PtxVersion(int& ptx_version, int device)
CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersion(int& ptx_version)
{
cudaError_t result = cudaErrorUnknown;
NV_IF_TARGET(
NV_IS_HOST,
(auto const device = CurrentDevice();
auto const payload = GetPerDeviceAttributeCache<PtxVersionCacheTag>()(
// If this call fails, then we get the error code back in the payload,
// which we check with `CubDebug` below.
[=](int& pv) {
return PtxVersionUncached(pv, device);
},
device);

if (!CubDebug(payload.error)) { ptx_version = payload.attribute; }

result = payload.error;),
( // NV_IS_DEVICE:
result = PtxVersionUncached(ptx_version);));

NV_IF_TARGET(NV_IS_HOST,
(result = PtxVersion(ptx_version, CurrentDevice());),
( // NV_IS_DEVICE:
result = PtxVersionUncached(ptx_version);));
return result;
}

Expand Down Expand Up @@ -477,8 +461,7 @@ CUB_RUNTIME_FUNCTION inline cudaError_t SmVersion(int& sm_version, int device =
NV_IF_TARGET(
NV_IS_HOST,
(auto const payload = GetPerDeviceAttributeCache<SmVersionCacheTag>()(
// If this call fails, then we get the error code back in
// the payload, which we check with `CubDebug` below.
// If this call fails, then we get the error code back in the payload, which we check with `CubDebug` below.
[=](int& pv) {
return SmVersionUncached(pv, device);
},
Expand Down Expand Up @@ -565,9 +548,8 @@ CUB_RUNTIME_FUNCTION inline cudaError_t DebugSyncStream(cudaStream_t stream)
CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool& has_uva)
{
has_uva = false;
cudaError_t error = cudaSuccess;
int device = -1;
error = CubDebug(cudaGetDevice(&device));
cudaError_t error = CubDebug(cudaGetDevice(&device));
if (cudaSuccess != error)
{
return error;
Expand Down

0 comments on commit 0760911

Please sign in to comment.