Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Faster Least Significant Digit Radix Sort Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
canonizer authored and alliepiper committed Oct 21, 2020
1 parent ea48955 commit 6405882
Show file tree
Hide file tree
Showing 8 changed files with 1,920 additions and 66 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
.p4config
*~
\#*
67 changes: 30 additions & 37 deletions cub/agent/agent_radix_sort_downsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,6 @@ namespace cub {
* Tuning policy types
******************************************************************************/

/**
* Radix ranking algorithm
*/
enum RadixRankAlgorithm
{
RADIX_RANK_BASIC,
RADIX_RANK_MEMOIZE,
RADIX_RANK_MATCH
};

/**
* Parameterizable tuning policy type for AgentRadixSortDownsweep
*/
Expand Down Expand Up @@ -137,6 +127,9 @@ struct AgentRadixSortDownsweep

RADIX_DIGITS = 1 << RADIX_BITS,
KEYS_ONLY = Equals<ValueT, NullType>::VALUE,
LOAD_WARP_STRIPED = RANK_ALGORITHM == RADIX_RANK_MATCH ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
};

// Input iterator wrapper type (for applying cache modifier)s
Expand All @@ -148,7 +141,15 @@ struct AgentRadixSortDownsweep
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE),
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>,
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH),
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY),
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ANY>,
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ATOMIC_OR>
>::Type
>::Type
>::Type
>::Type BlockRadixRankT;

Expand Down Expand Up @@ -303,16 +304,15 @@ struct AgentRadixSortDownsweep
}

/**
* Load a tile of keys (specialized for full tile, any ranking algorithm)
* Load a tile of keys (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadKeysT(temp_storage.load_keys).Load(
d_keys_in + block_offset, keys);
Expand All @@ -322,16 +322,15 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for partial tile, any ranking algorithm)
* Load a tile of keys (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -345,30 +344,29 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for full tile, match ranking algorithm)
* Load a tile of keys (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys);
}


/**
* Load a tile of keys (specialized for partial tile, match ranking algorithm)
* Load a tile of keys (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -377,17 +375,15 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item);
}


/**
* Load a tile of values (specialized for full tile, any ranking algorithm)
* Load a tile of values (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadValuesT(temp_storage.load_values).Load(
d_values_in + block_offset, values);
Expand All @@ -397,15 +393,14 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of values (specialized for partial tile, any ranking algorithm)
* Load a tile of values (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -419,28 +414,27 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of items (specialized for full tile, match ranking algorithm)
* Load a tile of items (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values);
}


/**
* Load a tile of items (specialized for partial tile, match ranking algorithm)
* Load a tile of items (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -449,7 +443,6 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items);
}


/**
* Truck along associated values
*/
Expand All @@ -470,7 +463,7 @@ struct AgentRadixSortDownsweep
block_offset,
valid_items,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

ScatterValues<FULL_TILE>(
values,
Expand Down Expand Up @@ -515,7 +508,7 @@ struct AgentRadixSortDownsweep
valid_items,
default_key,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

// Twiddle key bits if necessary
#pragma unroll
Expand Down
Loading

0 comments on commit 6405882

Please sign in to comment.