Skip to content

Commit

Permalink
Clean up AliasTemporaries (#1815)
Browse files Browse the repository at this point in the history
* Allow const size arrays in AliasTemporaries
* Fix integer types used
  • Loading branch information
bernhardmgruber authored Jun 6, 2024
1 parent 3424dd9 commit d022a20
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions cub/cub/util_temporary_storage.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ CUB_NAMESPACE_BEGIN
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

/**
* @brief Alias temporaries to externally-allocated device storage (or simply return the amount of
* storage needed).
* @brief Alias temporaries to externally-allocated device storage (or simply return the amount of storage needed).
*
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage.
Expand All @@ -73,18 +72,18 @@ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t AliasTemporaries(
void* d_temp_storage,
size_t& temp_storage_bytes,
void* (&allocations)[ALLOCATIONS],
size_t (&allocation_sizes)[ALLOCATIONS])
const size_t (&allocation_sizes)[ALLOCATIONS])
{
constexpr int ALIGN_BYTES = 256;
constexpr int ALIGN_MASK = ~(ALIGN_BYTES - 1);
constexpr size_t ALIGN_BYTES = 256;
constexpr size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);

// Compute exclusive prefix sum over allocation requests
size_t allocation_offsets[ALLOCATIONS];
size_t bytes_needed = 0;
for (int i = 0; i < ALLOCATIONS; ++i)
{
size_t allocation_bytes = (allocation_sizes[i] + ALIGN_BYTES - 1) & ALIGN_MASK;
allocation_offsets[i] = bytes_needed;
const size_t allocation_bytes = (allocation_sizes[i] + ALIGN_BYTES - 1) & ALIGN_MASK;
allocation_offsets[i] = bytes_needed;
bytes_needed += allocation_bytes;
}
bytes_needed += ALIGN_BYTES - 1;
Expand All @@ -103,7 +102,8 @@ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t AliasTemporaries(
}

// Alias
d_temp_storage = (void*) ((size_t(d_temp_storage) + ALIGN_BYTES - 1) & ALIGN_MASK);
d_temp_storage =
reinterpret_cast<void*>((reinterpret_cast<uintptr_t>(d_temp_storage) + ALIGN_BYTES - 1) & ALIGN_MASK);
for (int i = 0; i < ALLOCATIONS; ++i)
{
allocations[i] = static_cast<char*>(d_temp_storage) + allocation_offsets[i];
Expand Down

0 comments on commit d022a20

Please sign in to comment.