diff --git a/cub/cub/device/dispatch/dispatch_scan_by_key.cuh b/cub/cub/device/dispatch/dispatch_scan_by_key.cuh index cfc499a35d0..2139721f734 100644 --- a/cub/cub/device/dispatch/dispatch_scan_by_key.cuh +++ b/cub/cub/device/dispatch/dispatch_scan_by_key.cuh @@ -33,11 +33,10 @@ #pragma once -#include - #include #include #include +#include #include #include #include @@ -46,6 +45,8 @@ #include +#include + CUB_NAMESPACE_BEGIN /****************************************************************************** @@ -184,64 +185,6 @@ __global__ void DeviceScanByKeyInitKernel( } } -/****************************************************************************** - * Policy - ******************************************************************************/ - -template -struct DeviceScanByKeyPolicy -{ - using KeyT = cub::detail::value_t; - - static constexpr size_t MaxInputBytes = (cub::max)(sizeof(KeyT), - sizeof(AccumT)); - - static constexpr size_t CombinedInputBytes = sizeof(KeyT) + sizeof(AccumT); - - // SM350 - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> - { - static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 6; - static constexpr int ITEMS_PER_THREAD = - ((MaxInputBytes <= 8) - ? 6 - : Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, - CombinedInputBytes)); - - using ScanByKeyPolicyT = - AgentScanByKeyPolicy<128, - ITEMS_PER_THREAD, - BLOCK_LOAD_WARP_TRANSPOSE, - LOAD_CA, - BLOCK_SCAN_WARP_SCANS, - BLOCK_STORE_WARP_TRANSPOSE, - detail::default_reduce_by_key_delay_constructor_t>; - }; - - // SM520 - struct Policy520 : ChainedPolicy<520, Policy520, Policy350> - { - static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 9; - static constexpr int ITEMS_PER_THREAD = - ((MaxInputBytes <= 8) - ? 9 - : Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, - CombinedInputBytes)); - - using ScanByKeyPolicyT = - AgentScanByKeyPolicy<256, - ITEMS_PER_THREAD, - BLOCK_LOAD_WARP_TRANSPOSE, - LOAD_CA, - BLOCK_SCAN_WARP_SCANS, - BLOCK_STORE_WARP_TRANSPOSE, - detail::default_reduce_by_key_delay_constructor_t>; - }; - - using MaxPolicy = Policy520; -}; - /****************************************************************************** * Dispatch ******************************************************************************/ @@ -280,16 +223,16 @@ template < typename ScanOpT, typename InitValueT, typename OffsetT, - typename AccumT = - detail::accumulator_t< - ScanOpT, - cub::detail::conditional_t< - std::is_same::value, - cub::detail::value_t, - InitValueT>, - cub::detail::value_t>, - typename SelectedPolicy = - DeviceScanByKeyPolicy> + typename AccumT = + detail::accumulator_t::value, + cub::detail::value_t, + InitValueT>, + cub::detail::value_t>, + typename SelectedPolicy = DeviceScanByKeyPolicy, + ScanOpT>> struct DispatchScanByKey : SelectedPolicy { //--------------------------------------------------------------------- diff --git a/cub/cub/device/dispatch/tuning/tuning_scan_by_key.cuh b/cub/cub/device/dispatch/tuning/tuning_scan_by_key.cuh new file mode 100644 index 00000000000..68faf6b3ab3 --- /dev/null +++ b/cub/cub/device/dispatch/tuning/tuning_scan_by_key.cuh @@ -0,0 +1,620 @@ +/****************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +CUB_NAMESPACE_BEGIN + +namespace detail +{ +namespace scan_by_key +{ + +enum class primitive_accum { no, yes }; +enum class primitive_op { no, yes }; +enum class offset_size { _4, _8, unknown }; +enum class val_size { _1, _2, _4, _8, _16, unknown }; +enum class key_size { _1, _2, _4, _8, _16, unknown }; + +template +constexpr primitive_accum is_primitive_accum() +{ + return Traits::PRIMITIVE ? primitive_accum::yes : primitive_accum::no; +} + +template +constexpr primitive_op is_primitive_op() +{ + return basic_binary_op_t::value ? primitive_op::yes : primitive_op::no; +} + +template +constexpr val_size classify_val_size() +{ + return sizeof(ValueT) == 1 ? val_size::_1 + : sizeof(ValueT) == 2 ? val_size::_2 + : sizeof(ValueT) == 4 ? val_size::_4 + : sizeof(ValueT) == 8 ? val_size::_8 + : sizeof(ValueT) == 16 ? val_size::_16 + : val_size::unknown; +} + +template +constexpr key_size classify_key_size() +{ + return sizeof(KeyT) == 1 ? key_size::_1 + : sizeof(KeyT) == 2 ? key_size::_2 + : sizeof(KeyT) == 4 ? key_size::_4 + : sizeof(KeyT) == 8 ? key_size::_8 + : sizeof(KeyT) == 16 ? key_size::_16 + : key_size::unknown; +} + +template (), + val_size AccumSize = classify_val_size(), + primitive_accum PrimitiveAccumulator = is_primitive_accum()> +struct sm90_tuning +{ + static constexpr int nominal_4b_items_per_thread = 9; + + static constexpr int threads = 256; + + static constexpr size_t max_input_bytes = (cub::max)(sizeof(KeyT), sizeof(AccumT)); + + static constexpr size_t combined_input_bytes = sizeof(KeyT) + sizeof(AccumT); + + static constexpr int items = + ((max_input_bytes <= 8) + ? 9 + : Nominal4BItemsToItemsCombined(nominal_4b_items_per_thread, combined_input_bytes)); + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::default_reduce_by_key_delay_constructor_t; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<650>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + + static constexpr int items = 16; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<124, 995>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<488, 545>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<488, 1070>; +}; + +#if CUB_IS_INT128_ENABLED +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<936, 1105>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<936, 1105>; +}; +#endif + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<136, 785>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<445>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<312, 865>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<352, 1170>; +}; + +#if CUB_IS_INT128_ENABLED +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<504, 1190>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<504, 1190>; +}; +#endif + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<850>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<128, 965>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 288; + + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<700, 1005>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<556, 1195>; +}; + +#if CUB_IS_INT128_ENABLED +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<512, 1030>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<512, 1030>; +}; +#endif + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<504, 1010>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<420, 970>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 192; + + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<500, 1125>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<600, 930>; +}; + +#if CUB_IS_INT128_ENABLED +template +struct sm90_tuning +{ + static constexpr int threads = 192; + + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<364, 1085>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 192; + + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<364, 1085>; +}; +#endif + +template +struct sm90_tuning +{ + static constexpr int threads = 192; + + static constexpr int items = 7; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<500, 975>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<164, 1075>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + + static constexpr int items = 9; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<268, 1120>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 192; + + static constexpr int items = 9; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<320, 1200>; +}; + +#if CUB_IS_INT128_ENABLED +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<364, 1050>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int items = 23; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<364, 1050>; +}; +#endif +} // namespace scan_by_key +} // namespace detail + + +template +struct DeviceScanByKeyPolicy +{ + using KeyT = cub::detail::value_t; + + static constexpr size_t MaxInputBytes = (cub::max)(sizeof(KeyT), + sizeof(AccumT)); + + static constexpr size_t CombinedInputBytes = sizeof(KeyT) + sizeof(AccumT); + + // SM350 + struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + { + static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 6; + static constexpr int ITEMS_PER_THREAD = + ((MaxInputBytes <= 8) + ? 6 + : Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, + CombinedInputBytes)); + + using ScanByKeyPolicyT = + AgentScanByKeyPolicy<128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_CA, + BLOCK_SCAN_WARP_SCANS, + BLOCK_STORE_WARP_TRANSPOSE, + detail::default_reduce_by_key_delay_constructor_t>; + }; + + // SM520 + struct Policy520 : ChainedPolicy<520, Policy520, Policy350> + { + static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 9; + static constexpr int ITEMS_PER_THREAD = + ((MaxInputBytes <= 8) + ? 9 + : Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, + CombinedInputBytes)); + + using ScanByKeyPolicyT = + AgentScanByKeyPolicy<256, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_CA, + BLOCK_SCAN_WARP_SCANS, + BLOCK_STORE_WARP_TRANSPOSE, + detail::default_reduce_by_key_delay_constructor_t>; + }; + + // SM900 + struct Policy900 : ChainedPolicy<900, Policy900, Policy520> + { + using tuning = + detail::scan_by_key::sm90_tuning()>; + + using ScanByKeyPolicyT = AgentScanByKeyPolicy; + }; + + using MaxPolicy = Policy900; +}; + + +CUB_NAMESPACE_END