Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor CUB util_device
Browse files Browse the repository at this point in the history
bernhardmgruber committed Jul 6, 2024

Verified

This commit was signed with the committer’s verified signature.
lucasssvaz Lucas Saavedra Vaz
1 parent 20cd6ce commit 9ead91f
Showing 1 changed file with 13 additions and 31 deletions.
44 changes: 13 additions & 31 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
@@ -155,9 +155,10 @@ CUB_RUNTIME_FUNCTION inline int DeviceCountUncached()

/**
* \brief Cache for an arbitrary value produced by a nullary function.
* deprecated [Since 2.6.0]
*/
template <typename T, T (*Function)()>
struct ValueCache
struct CUB_DEPRECATED ValueCache
{
T const value;

@@ -170,13 +171,11 @@ struct ValueCache
{}
};

// Host code, only safely usable in C++11 or newer, where thread-safe
// initialization of static locals is guaranteed. This is a separate function
// to avoid defining a local static in a host/device function.
// Host code. This is a separate function to avoid defining a local static in a host/device function.
// Host-only accessor that returns a device count held in a function-local static, so the value is
// produced once and reused on later calls.
// NOTE(review): this body contains two declaration/return pairs. As written, only the first pair
// (the `ValueCache` object and `cache.value`) can execute; the `static int count` pair sits after an
// unconditional `return` and is unreachable. The two pairs look like the before/after sides of a
// diff whose +/- markers were lost -- confirm which pair is intended before relying on this text.
_CCCL_HOST inline int DeviceCountCachedValue()
{
// Variant 1: wrap the value in a ValueCache parameterized on DeviceCountUncached and read its `value` member.
static ValueCache<int, DeviceCountUncached> cache;
return cache.value;
// Variant 2: store the result of DeviceCountUncached() directly in a static int.
static int count = DeviceCountUncached();
return count;
}

/**
@@ -211,7 +210,7 @@ struct PerDeviceAttributeCache
// Each entry starts in the `DeviceEntryEmpty` state, then proceeds to the
// `DeviceEntryInitializing` state, and then proceeds to the
// `DeviceEntryReady` state. These are the only state transitions allowed;
// e.g. a linear sequence of transitions.
// i.e. a linear sequence of transitions.
enum DeviceEntryStatus
{
DeviceEntryEmpty = 0,
@@ -372,7 +371,6 @@ _CCCL_HOST inline cudaError_t PtxVersionUncached(int& ptx_version, int device)
// Returns a reference to a function-local static PerDeviceAttributeCache.
// `Tag` is not used in the body; it only distinguishes template instantiations, so every distinct
// Tag type gets its own static `cache` object. Elsewhere in this text the function is instantiated
// with PtxVersionCacheTag and SmVersionCacheTag.
template <typename Tag>
_CCCL_HOST inline PerDeviceAttributeCache& GetPerDeviceAttributeCache()
{
// C++11 guarantees that initialization of static locals is thread safe.
static PerDeviceAttributeCache cache;
return cache;
}
@@ -392,8 +390,7 @@ struct SmVersionCacheTag
_CCCL_HOST inline cudaError_t PtxVersion(int& ptx_version, int device)
{
auto const payload = GetPerDeviceAttributeCache<PtxVersionCacheTag>()(
// If this call fails, then we get the error code back in the payload,
// which we check with `CubDebug` below.
// If this call fails, then we get the error code back in the payload, which we check with `CubDebug` below.
[=](int& pv) {
return PtxVersionUncached(pv, device);
},
@@ -417,23 +414,10 @@ _CCCL_HOST inline cudaError_t PtxVersion(int& ptx_version, int device)
// Writes the PTX version for the current device into `ptx_version` and returns the resulting
// cudaError_t. `result` starts as cudaErrorUnknown and is assigned inside the NV_IF_TARGET branches:
//  - host branch: obtains the value for the device returned by CurrentDevice() through the
//    PtxVersionCacheTag per-device attribute cache;
//  - device branch: calls PtxVersionUncached(ptx_version) directly.
// NOTE(review): two NV_IF_TARGET invocations appear back to back. The first spells out the cache
// lookup inline; the second delegates the host branch to PtxVersion(ptx_version, CurrentDevice()).
// They look like the before/after sides of a diff whose +/- markers were lost; as written both
// would execute in sequence -- confirm which one is intended.
CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersion(int& ptx_version)
{
cudaError_t result = cudaErrorUnknown;
// Variant 1: host branch performs the per-device cache lookup inline.
NV_IF_TARGET(
NV_IS_HOST,
(auto const device = CurrentDevice();
auto const payload = GetPerDeviceAttributeCache<PtxVersionCacheTag>()(
// If this call fails, then we get the error code back in the payload,
// which we check with `CubDebug` below.
[=](int& pv) {
return PtxVersionUncached(pv, device);
},
device);

// `ptx_version` is only overwritten when CubDebug(payload.error) evaluates to false.
if (!CubDebug(payload.error)) { ptx_version = payload.attribute; }

result = payload.error;),
( // NV_IS_DEVICE:
result = PtxVersionUncached(ptx_version);));

// Variant 2: host branch forwards to the two-argument PtxVersion overload for the current device.
NV_IF_TARGET(NV_IS_HOST,
(result = PtxVersion(ptx_version, CurrentDevice());),
( // NV_IS_DEVICE:
result = PtxVersionUncached(ptx_version);));
return result;
}

@@ -477,8 +461,7 @@ CUB_RUNTIME_FUNCTION inline cudaError_t SmVersion(int& sm_version, int device =
NV_IF_TARGET(
NV_IS_HOST,
(auto const payload = GetPerDeviceAttributeCache<SmVersionCacheTag>()(
// If this call fails, then we get the error code back in
// the payload, which we check with `CubDebug` below.
// If this call fails, then we get the error code back in the payload, which we check with `CubDebug` below.
[=](int& pv) {
return SmVersionUncached(pv, device);
},
@@ -565,9 +548,8 @@ CUB_RUNTIME_FUNCTION inline cudaError_t DebugSyncStream(cudaStream_t stream)
CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool& has_uva)
{
has_uva = false;
cudaError_t error = cudaSuccess;
int device = -1;
error = CubDebug(cudaGetDevice(&device));
cudaError_t error = CubDebug(cudaGetDevice(&device));
if (cudaSuccess != error)
{
return error;

0 comments on commit 9ead91f

Please sign in to comment.