Skip to content

Commit

Permalink
Merge pull request #121 from senior-zero/enh-main/github/sm90_scan_by_key
Browse files: browse the repository at this point in the history

Tune scan by key for SM90
  • Loading branch information
gevtushenko authored Jun 28, 2023
2 parents db2c5ce + d9484d4 commit abec065
Show file tree
Hide file tree
Showing 2 changed files with 633 additions and 70 deletions.
83 changes: 13 additions & 70 deletions cub/cub/device/dispatch/dispatch_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@

#pragma once

#include <iterator>

#include <cub/agent/agent_scan_by_key.cuh>
#include <cub/config.cuh>
#include <cub/device/dispatch/dispatch_scan.cuh>
#include <cub/device/dispatch/tuning/tuning_scan_by_key.cuh>
#include <cub/thread/thread_operators.cuh>
#include <cub/util_debug.cuh>
#include <cub/util_deprecated.cuh>
Expand All @@ -46,6 +45,8 @@

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

#include <iterator>

CUB_NAMESPACE_BEGIN

/******************************************************************************
Expand Down Expand Up @@ -184,64 +185,6 @@ __global__ void DeviceScanByKeyInitKernel(
}
}

/******************************************************************************
* Policy
******************************************************************************/

/**
 * Tuning policy collection for device-wide scan-by-key.
 *
 * Exposes one nested policy per supported architecture threshold
 * (Policy350, Policy520) and names the most recent one as MaxPolicy.
 * Each nested policy derives the per-thread item count from the sizes of
 * the key and accumulator types and bundles it into an
 * AgentScanByKeyPolicy.
 *
 * @tparam KeysInputIteratorT  Iterator type over the input keys
 * @tparam AccumT              Accumulator type used by the scan
 */
template <typename KeysInputIteratorT, typename AccumT>
struct DeviceScanByKeyPolicy
{
  // Key type, taken from the keys input iterator.
  typedef cub::detail::value_t<KeysInputIteratorT> KeyT;

  // Size in bytes of the larger of the key type and the accumulator type.
  static constexpr size_t MaxInputBytes =
    (sizeof(KeyT) < sizeof(AccumT)) ? sizeof(AccumT) : sizeof(KeyT);

  // Size in bytes of one key plus one accumulator.
  static constexpr size_t CombinedInputBytes = sizeof(KeyT) + sizeof(AccumT);

  // SM350
  struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
  {
    static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 6;

    // When neither type exceeds 8 bytes the nominal count is used as-is;
    // otherwise it is passed through Nominal4BItemsToItemsCombined together
    // with the combined key + accumulator size.
    // NOTE(review): that helper is defined elsewhere; presumably it scales
    // the count down for wider inputs -- confirm against its definition.
    static constexpr int ITEMS_PER_THREAD =
      (MaxInputBytes > 8)
        ? Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD,
                                        CombinedInputBytes)
        : NOMINAL_4B_ITEMS_PER_THREAD;

    // Agent configuration: 128 as the first argument, followed by the
    // item count and the load / scan / store algorithm selections.
    typedef AgentScanByKeyPolicy<
      128,
      ITEMS_PER_THREAD,
      BLOCK_LOAD_WARP_TRANSPOSE,
      LOAD_CA,
      BLOCK_SCAN_WARP_SCANS,
      BLOCK_STORE_WARP_TRANSPOSE,
      detail::default_reduce_by_key_delay_constructor_t<AccumT, int>>
      ScanByKeyPolicyT;
  };

  // SM520
  struct Policy520 : ChainedPolicy<520, Policy520, Policy350>
  {
    static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 9;

    // Same selection rule as Policy350, with a nominal count of 9.
    static constexpr int ITEMS_PER_THREAD =
      (MaxInputBytes > 8)
        ? Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD,
                                        CombinedInputBytes)
        : NOMINAL_4B_ITEMS_PER_THREAD;

    // Agent configuration: 256 as the first argument; the remaining
    // arguments match Policy350.
    typedef AgentScanByKeyPolicy<
      256,
      ITEMS_PER_THREAD,
      BLOCK_LOAD_WARP_TRANSPOSE,
      LOAD_CA,
      BLOCK_SCAN_WARP_SCANS,
      BLOCK_STORE_WARP_TRANSPOSE,
      detail::default_reduce_by_key_delay_constructor_t<AccumT, int>>
      ScanByKeyPolicyT;
  };

  // Policy for the highest architecture threshold listed above.
  typedef Policy520 MaxPolicy;
};

/******************************************************************************
* Dispatch
******************************************************************************/
Expand Down Expand Up @@ -280,16 +223,16 @@ template <
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT =
detail::accumulator_t<
ScanOpT,
cub::detail::conditional_t<
std::is_same<InitValueT, NullType>::value,
cub::detail::value_t<ValuesInputIteratorT>,
InitValueT>,
cub::detail::value_t<ValuesInputIteratorT>>,
typename SelectedPolicy =
DeviceScanByKeyPolicy<KeysInputIteratorT, AccumT>>
typename AccumT =
detail::accumulator_t<ScanOpT,
cub::detail::conditional_t<std::is_same<InitValueT, NullType>::value,
cub::detail::value_t<ValuesInputIteratorT>,
InitValueT>,
cub::detail::value_t<ValuesInputIteratorT>>,
typename SelectedPolicy = DeviceScanByKeyPolicy<KeysInputIteratorT,
AccumT,
cub::detail::value_t<ValuesInputIteratorT>,
ScanOpT>>
struct DispatchScanByKey : SelectedPolicy
{
//---------------------------------------------------------------------
Expand Down
Loading

0 comments on commit abec065

Please sign in to comment.