This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

cub::ThreadLoadAsync and friends, abstractions for asynchronous data movement #209

Closed · wants to merge 1 commit
204 changes: 184 additions & 20 deletions cub/thread/thread_load.cuh
@@ -59,13 +59,20 @@ namespace cub {
*/
enum CacheLoadModifier
{
-    LOAD_DEFAULT,        ///< Default (no modifier)
-    LOAD_CA,             ///< Cache at all levels
-    LOAD_CG,             ///< Cache at global level
-    LOAD_CS,             ///< Cache streaming (likely to be accessed once)
-    LOAD_CV,             ///< Cache as volatile (including cached system lines)
-    LOAD_LDG,            ///< Cache as texture
-    LOAD_VOLATILE,       ///< Volatile (any memory space)
+    LOAD_DEFAULT,                    ///< Default (no modifier)
+    LOAD_CA,                         ///< Cache at all levels
+    LOAD_CACHE_ALWAYS = LOAD_CA,     ///< Cache at all levels
+    LOAD_CG,                         ///< Cache at global level
+    LOAD_CACHE_GLOBAL = LOAD_CG,     ///< Cache at global level
+    LOAD_CS,                         ///< Cache streaming (likely to be accessed once)
+    LOAD_CACHE_STREAMING = LOAD_CS,  ///< Cache streaming (likely to be accessed once)
+    LOAD_CV,                         ///< Cache as volatile (including cached system lines)
+    LOAD_CACHE_VOLATILE = LOAD_CV,   ///< Cache as volatile (including cached system lines)
+    LOAD_LDG,                        ///< Cache as texture
+    LOAD_VOLATILE,                   ///< Volatile (any memory space)
+    LOAD_LAST_USE,                   ///< Indicate the line will not be used again
+    LOAD_WRITE_BACK,                 ///< Write back at all coherent levels
+    LOAD_WRITE_THROUGH               ///< Write through to system memory
};
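For reference, the long-form names are plain aliases of the existing enumerators, so either spelling selects the same modifier. A minimal sketch (the kernel name and buffers are hypothetical):

__global__ void AliasDemo(const int *d_in, int *d_out)
{
    // LOAD_CACHE_GLOBAL is the same enumerator as LOAD_CG, so both calls
    // compile to the same cache-modified load.
    static_assert(cub::LOAD_CACHE_GLOBAL == cub::LOAD_CG, "alias");
    int a = cub::ThreadLoad<cub::LOAD_CG>(d_in + threadIdx.x);
    int b = cub::ThreadLoad<cub::LOAD_CACHE_GLOBAL>(d_in + threadIdx.x);
    d_out[threadIdx.x] = a + b;
}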


@@ -83,7 +90,7 @@ enum CacheLoadModifier
*
* // 32-bit load using cache-global modifier:
* int *d_in;
- * int val = cub::ThreadLoad<cub::LOAD_CA>(d_in + threadIdx.x);
+ * int val = cub::ThreadLoad<cub::LOAD_CG>(d_in + threadIdx.x);
*
* // 16-bit load using default modifier
* short *d_in;
@@ -273,18 +280,10 @@ struct IterateThreadLoad<MAX, MAX>
/**
* Define powers-of-two ThreadLoad specializations for the various Cache load modifiers
*/
-#if CUB_PTX_ARCH >= 200
-_CUB_LOAD_ALL(LOAD_CA, ca)
-_CUB_LOAD_ALL(LOAD_CG, cg)
-_CUB_LOAD_ALL(LOAD_CS, cs)
-_CUB_LOAD_ALL(LOAD_CV, cv)
-#else
-_CUB_LOAD_ALL(LOAD_CA, global)
-// Use volatile to ensure coherent reads when this PTX is JIT'd to run on newer architectures with L1
-_CUB_LOAD_ALL(LOAD_CG, volatile.global)
-_CUB_LOAD_ALL(LOAD_CS, global)
-_CUB_LOAD_ALL(LOAD_CV, volatile.global)
-#endif
+_CUB_LOAD_ALL(LOAD_CA, ca)
+_CUB_LOAD_ALL(LOAD_CG, cg)
+_CUB_LOAD_ALL(LOAD_CS, cs)
+_CUB_LOAD_ALL(LOAD_CV, cv)

#if CUB_PTX_ARCH >= 350
_CUB_LOAD_ALL(LOAD_LDG, global.nc)
@@ -416,6 +415,171 @@ __device__ __forceinline__ typename std::iterator_traits<InputIteratorT>::value_type
}


#endif // DOXYGEN_SHOULD_SKIP_THIS


/**
* \name Asynchronous Thread I/O (cache modified)
* @{
*/

/**
* \brief Thread utility for asynchronously reading memory using cub::CacheLoadModifier cache modifiers. Can be used to load any data type.
*
* \par Example
* \code
* #include <cub/cub.cuh> // or equivalently <cub/thread/thread_load.cuh>
*
* // 4x 32-bit load using cache-always modifier:
* int *d_in;
* __shared__ int val[...];
* cub::ThreadLoadAsync<cub::LOAD_CA>(d_in + 4 * threadIdx.x + 0, val + 4 * threadIdx.x + 0);
* cub::ThreadLoadAsync<cub::LOAD_CA>(d_in + 4 * threadIdx.x + 1, val + 4 * threadIdx.x + 1);
* cub::ThreadLoadAsync<cub::LOAD_CA>(d_in + 4 * threadIdx.x + 2, val + 4 * threadIdx.x + 2);
* cub::ThreadLoadAsync<cub::LOAD_CA>(d_in + 4 * threadIdx.x + 3, val + 4 * threadIdx.x + 3);
* cub::ThreadLoadWait();
*
* // 4x 128-bit load using cache-global modifier:
* int4 *d_in;
* __shared__ int4 val[...];
* cub::ThreadLoadAsync<cub::LOAD_CG>(d_in + 4 * threadIdx.x + 0, val + 4 * threadIdx.x + 0);
* cub::ThreadLoadAsync<cub::LOAD_CG>(d_in + 4 * threadIdx.x + 1, val + 4 * threadIdx.x + 1);
* cub::ThreadLoadAsync<cub::LOAD_CG>(d_in + 4 * threadIdx.x + 2, val + 4 * threadIdx.x + 2);
* cub::ThreadLoadAsync<cub::LOAD_CG>(d_in + 4 * threadIdx.x + 3, val + 4 * threadIdx.x + 3);
* cub::ThreadLoadCommit();
* // ...
* cub::ThreadLoadWait();
* \endcode
*
* \tparam MODIFIER <b>[inferred]</b> CacheLoadModifier enumeration
* \tparam ForwardIteratorT <b>[inferred]</b> Input iterator type \iterator
*/
template <
CacheLoadModifier MODIFIER,
typename ForwardIteratorT>
__device__ __forceinline__ void ThreadLoadAsync(ForwardIteratorT input, ForwardIteratorT output);


/**
* \brief Establishes an ordering w.r.t. previously issued ThreadLoadAsync operations.
*/
// Review comment (Collaborator, PR author): This should probably have a comment
// saying that this means that prior operations have been read from the source,
// although they have not necessarily been stored to the destination.
__device__ __forceinline__ void ThreadLoadCommit() {
#if CUB_PTX_ARCH >= 800
asm volatile("cp.async.commit_group;\n" ::);
#endif
}


/**
* \brief Blocks until all but the most recent N groups of ThreadLoadAsync operations have completed.
*/
template <size_t N>
__device__ __forceinline__ void ThreadLoadWaitFor() {
#if CUB_PTX_ARCH >= 800
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
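For instance (a sketch with hypothetical src_a/dst_a and src_b/dst_b pointer pairs), after two committed groups a wait with N = 1 guarantees only that the older group has completed:

cub::ThreadLoadAsync<cub::LOAD_CG>(src_a, dst_a);  // group 1
cub::ThreadLoadCommit();
cub::ThreadLoadAsync<cub::LOAD_CG>(src_b, dst_b);  // group 2
cub::ThreadLoadCommit();
cub::ThreadLoadWaitFor<1>();  // group 1 is complete; group 2 may still be in flight
// *dst_a is now safe for this thread to read.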


/**
* \brief Blocks until all previous ThreadLoadAsync operations have completed.
*/
template <>
__device__ __forceinline__ void ThreadLoadWaitFor<0>() {
#if CUB_PTX_ARCH >= 800
asm volatile("cp.async.wait_all;\n" ::);
#endif
}

/**
* \brief Blocks until all previous ThreadLoadAsync operations have completed.
*/
__device__ __forceinline__ void ThreadLoadWait() {
ThreadLoadWaitFor<0>();
}

//@} end member group
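Taken together, these primitives support a simple software pipeline in which each thread prefetches the next tile while the current one is consumed. A minimal double-buffered sketch, assuming a 128-thread block and hypothetical names (PipelinedKernel, Consume, num_tiles):

__global__ void PipelinedKernel(int4 *d_in, int num_tiles)
{
    __shared__ int4 buf[2][128];

    // Prime the pipeline: issue tile 0 and commit it as its own group.
    cub::ThreadLoadAsync<cub::LOAD_CG>(d_in + threadIdx.x, &buf[0][threadIdx.x]);
    cub::ThreadLoadCommit();

    for (int i = 0; i < num_tiles; ++i)
    {
        if (i + 1 < num_tiles)
        {
            // Prefetch tile i+1 into the other buffer as a second group.
            cub::ThreadLoadAsync<cub::LOAD_CG>(d_in + (i + 1) * 128 + threadIdx.x,
                                               &buf[(i + 1) & 1][threadIdx.x]);
            cub::ThreadLoadCommit();
            // Wait until at most one group (the prefetch) is still in flight,
            // i.e. until tile i has arrived in shared memory.
            cub::ThreadLoadWaitFor<1>();
        }
        else
        {
            cub::ThreadLoadWait(); // last tile: drain all outstanding loads
        }
        __syncthreads();
        Consume(buf[i & 1][threadIdx.x]); // hypothetical per-tile work
        __syncthreads();
    }
}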


#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

/**
* 16-byte LOAD_CACHE_GLOBAL ThreadLoadAsync specialization (uses cp.async.cg).
*/
template <typename ForwardIteratorT>
__device__ __forceinline__ void ThreadLoadAsync(
ForwardIteratorT input,
ForwardIteratorT output,
Int2Type<LOAD_CACHE_GLOBAL> /*modifier*/,
Int2Type<16>                /*size*/,
Int2Type<true>              /*is_pointer*/,
Int2Type<true>              /*is_async*/)
{
#if CUB_PTX_ARCH >= 800
if (__isGlobal(input) && __isShared(output))
{
// 16-byte global->shared copy with the cache-global (.cg) qualifier,
// which stages through L2 and bypasses L1.
asm volatile ("cp.async.cg.shared.global [%0], [%1], 16;"
:
: "r"(static_cast<unsigned>(__cvta_generic_to_shared(output))),
"l"(input)
: "memory");
return;
}
#endif
// Fall back to a synchronous load when the operands are not a
// global->shared pair (or on architectures without cp.async).
*output = ThreadLoad<LOAD_CACHE_GLOBAL>(input);
}


/**
* Generic ThreadLoadAsync definition for 4-, 8-, and 16-byte transfers.
*/
template <
CacheLoadModifier MODIFIER,
typename ForwardIteratorT,
int N>
__device__ __forceinline__ void ThreadLoadAsync(
ForwardIteratorT input,
ForwardIteratorT output,
Int2Type<MODIFIER> /*modifier*/,
Int2Type<N>        /*size*/,
Int2Type<true>     /*is_pointer*/,
Int2Type<true>     /*is_async*/)
{
#if CUB_PTX_ARCH >= 800
if (__isGlobal(input) && __isShared(output))
{
// N-byte global->shared copy (N is 4, 8, or 16) with the cache-all
// (.ca) qualifier; cp-size and src-size are both N.
asm volatile ("cp.async.ca.shared.global [%0], [%1], %2, %3;"
:
: "r"(static_cast<unsigned>(__cvta_generic_to_shared(output))),
"l"(input), "n"(N), "n"(N)
: "memory");
return;
}
#endif
// Fall back to a synchronous load when the operands are not a
// global->shared pair (or on architectures without cp.async).
*output = ThreadLoad<MODIFIER>(input);
}


/**
* ThreadLoadAsync tag dispatcher.
*/
template <
CacheLoadModifier MODIFIER,
typename ForwardIteratorT>
__device__ __forceinline__ void ThreadLoadAsync(ForwardIteratorT input, ForwardIteratorT output)
{
// Apply tags for partial-specialization.
ThreadLoadAsync(
input,
output,
// The cache modifier.
Int2Type<MODIFIER>(),
// The size of the value type.
Int2Type<sizeof(typename std::iterator_traits<ForwardIteratorT>::value_type)>(),
// Is the iterator a pointer?
Int2Type<IsPointer<ForwardIteratorT>::VALUE>(),
// Can it actually be async? (cp.async supports only 4-, 8-, and 16-byte transfers.)
Int2Type<
sizeof(typename std::iterator_traits<ForwardIteratorT>::value_type) == 4
|| sizeof(typename std::iterator_traits<ForwardIteratorT>::value_type) == 8
|| sizeof(typename std::iterator_traits<ForwardIteratorT>::value_type) == 16
>());
}
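The net effect: a pointer to a 4-, 8-, or 16-byte type takes the asynchronous path, and everything else degrades to a synchronous ThreadLoad. A sketch of both outcomes (hypothetical kernel, and assuming the non-async fallback overloads exist alongside those shown here):

__global__ void DispatchDemo(int4 *d_vec, char *d_bytes)
{
    __shared__ int4 s_vec[128];
    __shared__ char s_byte[128];

    // sizeof(int4) == 16 and both operands are pointers, so this resolves to
    // the cp.async specialization above (global -> shared).
    cub::ThreadLoadAsync<cub::LOAD_CACHE_GLOBAL>(d_vec + threadIdx.x, s_vec + threadIdx.x);

    // sizeof(char) == 1, so the is_async tag is false and the call falls back
    // to a synchronous ThreadLoad.
    cub::ThreadLoadAsync<cub::LOAD_CACHE_GLOBAL>(d_bytes + threadIdx.x, s_byte + threadIdx.x);

    cub::ThreadLoadWait();
    // ... use s_vec / s_byte ...
}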


#endif // DOXYGEN_SHOULD_SKIP_THIS

Expand Down