diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index df77bfb6..5f301823 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -206,6 +206,7 @@ test:rocm_install: .nvcc: extends: - .deps:nvcc + - .gpus:nvcc-gpus - .deps:cmake-minimum before_script: - !reference [".deps:nvcc", before_script] @@ -220,6 +221,7 @@ build:nvcc: -D CMAKE_BUILD_TYPE=Release -D BUILD_TEST=ON -D BUILD_EXAMPLE=ON + -D NVGPU_TARGETS="$GPU_TARGETS" -B build -S . - cmake --build build @@ -251,6 +253,7 @@ build:nvcc-benchmark: -D BUILD_BENCHMARK=ON -D CMAKE_CXX_COMPILER=g++-8 -D CMAKE_C_COMPILER=gcc-8 + -D NVGPU_TARGETS="$GPU_TARGETS" -B build -S . - cmake --build build diff --git a/CHANGELOG.md b/CHANGELOG.md index f45af54d..9b1e74b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,21 @@ # Change Log for hipCUB See README.md on how to build the hipCUB documentation using Doxygen. -## (Unreleased) hipCUB-2.11.1 for ROCm 5.2.0 + +## (Unreleased) hipCUB-2.12.0 for ROCm 5.2.0 ### Added -- Packages for tests and benchmark executable on all supported OSes using CPack. -## (Unreleased) hipCUB-2.11.0 for ROCm 5.1.0 +- UniqueByKey device algorithm +- SubtractLeft, SubtractLeftPartialTile, SubtractRight, SubtractRightPartialTile overloads in BlockAdjacentDifference. + - The old overloads (FlagHeads, FlagTails, FlagHeadsAndTails) are deprecated. +- DeviceAdjacentDifference algorithm. +### Changed +- Obsoleted type traits defined in util_type.hpp. Use the standard library equivalents instead. +- CUB backend references CUB and thrust version 1.16.0. +- DeviceRadixSort's num_items parameter's type is now templated instead of being an int. + - If an integral type with a size at most 4 bytes is passed (e.g. an int), the former logic applies. + - Otherwise the algorithm uses a larger indexing type that makes it possible to sort input data over 2**32 elements. 
+ +## (Released) hipCUB-2.11.0 for ROCm 5.1.0 ### Added - Device segmented sort - Warp merge sort, WarpMask and thread sort from cub 1.15.0 supported in hipCUB diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 1c5b5738..ac61baf1 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -68,6 +68,7 @@ endfunction() # **************************************************************************** # Benchmarks # **************************************************************************** +add_hipcub_benchmark(benchmark_block_adjacent_difference.cpp) add_hipcub_benchmark(benchmark_block_discontinuity.cpp) add_hipcub_benchmark(benchmark_block_exchange.cpp) add_hipcub_benchmark(benchmark_block_histogram.cpp) @@ -76,6 +77,7 @@ add_hipcub_benchmark(benchmark_block_radix_sort.cpp) add_hipcub_benchmark(benchmark_block_reduce.cpp) add_hipcub_benchmark(benchmark_block_run_length_decode.cpp) add_hipcub_benchmark(benchmark_block_scan.cpp) +add_hipcub_benchmark(benchmark_device_adjacent_difference.cpp) add_hipcub_benchmark(benchmark_device_histogram.cpp) add_hipcub_benchmark(benchmark_device_partition.cpp) add_hipcub_benchmark(benchmark_device_radix_sort.cpp) diff --git a/benchmark/benchmark_block_adjacent_difference.cpp b/benchmark/benchmark_block_adjacent_difference.cpp new file mode 100644 index 00000000..b14d0cb1 --- /dev/null +++ b/benchmark/benchmark_block_adjacent_difference.cpp @@ -0,0 +1,420 @@ +// MIT License +// +// Copyright (c) 2020 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "common_benchmark_header.hpp" + +// HIP API +#include "hipcub/block/block_adjacent_difference.hpp" + +#include "hipcub/block/block_load.hpp" +#include "hipcub/block/block_store.hpp" + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 128; +#endif + +template < + class Benchmark, + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool WithTile, + typename... 
Args +> +__global__ +__launch_bounds__(BlockSize) +void kernel(Args ...args) +{ + Benchmark::template run(args...); +} + +template +struct minus +{ + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a - b; + } +}; + +struct subtract_left +{ + template + __device__ static void run(const T* d_input, T* d_output, unsigned int trials) + { + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; + + T input[ItemsPerThread]; + hipcub::LoadDirectStriped(lid, d_input + block_offset, input); + + hipcub::BlockAdjacentDifference adjacent_difference; + + #pragma nounroll + for(unsigned int trial = 0; trial < trials; trial++) + { + T output[ItemsPerThread]; + if(WithTile) + { + adjacent_difference.SubtractLeft(input, output, minus{}, T(123)); + } + else + { + adjacent_difference.SubtractLeft(input, output, minus{}); + } + + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + input[i] += output[i]; + } + + __syncthreads(); + } + + hipcub::StoreDirectStriped(lid, d_output + block_offset, input); + } +}; + +struct subtract_left_partial_tile +{ + template + __device__ static void run(const T* d_input, int* tile_sizes, T* d_output, unsigned int trials) + { + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; + + T input[ItemsPerThread]; + hipcub::LoadDirectStriped(lid, d_input + block_offset, input); + + hipcub::BlockAdjacentDifference adjacent_difference; + + int tile_size = tile_sizes[blockIdx.x]; + + // Try to evenly distribute the length of tile_sizes between all the trials + const auto tile_size_diff = (BlockSize * ItemsPerThread) / trials + 1; + + #pragma nounroll + for(unsigned int trial = 0; trial < trials; trial++) + { + T output[ItemsPerThread]; + + adjacent_difference.SubtractLeftPartialTile(input, output, minus{}, tile_size); + + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + input[i] += 
output[i]; + } + + // Change the tile_size to even out the distribution + tile_size = (tile_size + tile_size_diff) % (BlockSize * ItemsPerThread); + __syncthreads(); + } + + hipcub::StoreDirectStriped(lid, d_output + block_offset, input); + } +}; + +struct subtract_right +{ + template + __device__ static void run(const T* d_input, T* d_output, unsigned int trials) + { + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; + + T input[ItemsPerThread]; + hipcub::LoadDirectStriped(lid, d_input + block_offset, input); + + hipcub::BlockAdjacentDifference adjacent_difference; + + #pragma nounroll + for(unsigned int trial = 0; trial < trials; trial++) + { + T output[ItemsPerThread]; + if(WithTile) + { + adjacent_difference.SubtractRight(input, output, minus{}, T(123)); + } + else + { + adjacent_difference.SubtractRight(input, output, minus{}); + } + + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + input[i] += output[i]; + } + + __syncthreads(); + } + + hipcub::StoreDirectStriped(lid, d_output + block_offset, input); + } +}; + +struct subtract_right_partial_tile +{ + template + __device__ static void run(const T* d_input, int* tile_sizes, T* d_output, unsigned int trials) + { + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; + + T input[ItemsPerThread]; + hipcub::LoadDirectStriped(lid, d_input + block_offset, input); + + hipcub::BlockAdjacentDifference adjacent_difference; + + int tile_size = tile_sizes[blockIdx.x]; + + // Try to evenly distribute the length of tile_sizes between all the trials + const auto tile_size_diff = (BlockSize * ItemsPerThread) / trials + 1; + + #pragma nounroll + for(unsigned int trial = 0; trial < trials; trial++) + { + T output[ItemsPerThread]; + + adjacent_difference.SubtractRightPartialTile(input, output, minus{}, tile_size); + + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + input[i] += output[i]; + 
} + + // Change the tile_size to even out the distribution + tile_size = (tile_size + tile_size_diff) % (BlockSize * ItemsPerThread); + __syncthreads(); + } + + hipcub::StoreDirectStriped(lid, d_output + block_offset, input); + } +}; + +template +auto run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) + -> std::enable_if_t::value + && !std::is_same::value> +{ + constexpr auto items_per_block = BlockSize * ItemsPerThread; + const auto num_blocks = (N + items_per_block - 1) / items_per_block; + // Round up size to the next multiple of items_per_block + const auto size = num_blocks * items_per_block; + + const std::vector input = benchmark_utils::get_random_data(size, T(0), T(10)); + T* d_input; + T* d_output; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK(hipMalloc(&d_output, input.size() * sizeof(T))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(input[0]), + hipMemcpyHostToDevice + ) + ); + + for(auto _ : state) + { + auto start = std::chrono::high_resolution_clock::now(); + + hipLaunchKernelGGL( + HIP_KERNEL_NAME(kernel), + dim3(num_blocks), dim3(BlockSize), 0, stream, + d_input, d_output, Trials + ); + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed_seconds.count()); + } + state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); + state.SetItemsProcessed(state.iterations() * Trials * size); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); +} + +template +auto run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) + -> std::enable_if_t::value + || std::is_same::value> +{ + constexpr auto items_per_block = BlockSize * ItemsPerThread; + const auto num_blocks = (N + items_per_block - 1) / items_per_block; + // Round up size to the next multiple of items_per_block + 
const auto size = num_blocks * items_per_block; + + const std::vector input = benchmark_utils::get_random_data(size, T(0), T(10)); + const std::vector tile_sizes + = benchmark_utils::get_random_data(num_blocks, 0, items_per_block); + + T* d_input; + int* d_tile_sizes; + T* d_output; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK(hipMalloc(&d_tile_sizes, tile_sizes.size() * sizeof(tile_sizes[0]))); + HIP_CHECK(hipMalloc(&d_output, input.size() * sizeof(T))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(input[0]), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_tile_sizes, tile_sizes.data(), + tile_sizes.size() * sizeof(tile_sizes[0]), + hipMemcpyHostToDevice + ) + ); + + for(auto _ : state) + { + auto start = std::chrono::high_resolution_clock::now(); + + hipLaunchKernelGGL( + HIP_KERNEL_NAME(kernel), + dim3(num_blocks), dim3(BlockSize), 0, stream, + d_input, d_tile_sizes, d_output, Trials + ); + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed_seconds.count()); + } + state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); + state.SetItemsProcessed(state.iterations() * Trials * size); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_tile_sizes)); + HIP_CHECK(hipFree(d_output)); +} + +#define CREATE_BENCHMARK(T, BS, IPT, WITH_TILE) \ +benchmark::RegisterBenchmark( \ + (std::string("block_adjacent_difference<" #T ", " #BS ">.") + name + ("<" #IPT ", " #WITH_TILE ">")).c_str(), \ + &run_benchmark, \ + stream, size \ +) + +#define BENCHMARK_TYPE(type, block, with_tile) \ + CREATE_BENCHMARK(type, block, 1, with_tile), \ + CREATE_BENCHMARK(type, block, 3, with_tile), \ + CREATE_BENCHMARK(type, block, 4, with_tile), \ + CREATE_BENCHMARK(type, block, 8, with_tile), \ + CREATE_BENCHMARK(type, block, 
16, with_tile), \ + CREATE_BENCHMARK(type, block, 32, with_tile) + +template +void add_benchmarks(const std::string& name, + std::vector& benchmarks, + hipStream_t stream, + size_t size) +{ + std::vector bs = + { + BENCHMARK_TYPE(int, 256, false), + BENCHMARK_TYPE(float, 256, false), + BENCHMARK_TYPE(int8_t, 256, false), + BENCHMARK_TYPE(long long, 256, false), + BENCHMARK_TYPE(double, 256, false) + }; + + if(!std::is_same::value + && !std::is_same::value) { + bs.insert(bs.end(), { + BENCHMARK_TYPE(int, 256, true), + BENCHMARK_TYPE(float, 256, true), + BENCHMARK_TYPE(int8_t, 256, true), + BENCHMARK_TYPE(long long, 256, true), + BENCHMARK_TYPE(double, 256, true) + }); + } + + benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); +} + +int main(int argc, char *argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + + // HIP + hipStream_t stream = 0; // default + hipDeviceProp_t devProp; + int device_id = 0; + HIP_CHECK(hipGetDevice(&device_id)); + HIP_CHECK(hipGetDeviceProperties(&devProp, device_id)); + std::cout << "[HIP] Device name: " << devProp.name << std::endl; + + // Add benchmarks + std::vector benchmarks; + add_benchmarks("SubtractLeft", benchmarks, stream, size); + add_benchmarks("SubtractRight", benchmarks, stream, size); + add_benchmarks("SubtractLeftPartialTile", benchmarks, stream, size); + add_benchmarks("SubtractRightPartialTile", benchmarks, stream, size); + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + 
benchmark::RunSpecifiedBenchmarks(); + return 0; +} \ No newline at end of file diff --git a/benchmark/benchmark_device_adjacent_difference.cpp b/benchmark/benchmark_device_adjacent_difference.cpp new file mode 100644 index 00000000..0e782afe --- /dev/null +++ b/benchmark/benchmark_device_adjacent_difference.cpp @@ -0,0 +1,258 @@ +// MIT License +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +// CUB's implementation of DeviceRunLengthEncode has unused parameters, +// disable the warning because all warnings are treated as errors: + +#include "common_benchmark_header.hpp" + +#include + +#include "cmdparser.hpp" + +#include + +#include + +#include +#include +#include +#include +#include + +namespace +{ + +#ifndef DEFAULT_N +constexpr std::size_t DEFAULT_N = 1024 * 1024 * 128; +#endif + +constexpr unsigned int batch_size = 10; +constexpr unsigned int warmup_size = 5; + +template +auto dispatch_adjacent_difference(std::true_type /*left*/, + std::true_type /*copy*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractLeftCopy( + temporary_storage, storage_size, input, output, std::forward(args)...); +} + +template +auto dispatch_adjacent_difference(std::false_type /*left*/, + std::true_type /*copy*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractRightCopy( + temporary_storage, storage_size, input, output, std::forward(args)...); +} + +template +auto dispatch_adjacent_difference(std::true_type /*left*/, + std::false_type /*copy*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt /*output*/, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractLeft( + temporary_storage, storage_size, input, std::forward(args)...); +} + +template +auto dispatch_adjacent_difference(std::false_type /*left*/, + std::false_type /*copy*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt /*output*/, + Args&&... 
args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractRight( + temporary_storage, storage_size, input, std::forward(args)...); +} + +template +void run_benchmark(benchmark::State& state, const std::size_t size, const hipStream_t stream) +{ + using output_type = T; + + static constexpr bool debug_synchronous = false; + + // Generate data + const std::vector input = benchmark_utils::get_random_data(size, 1, 100); + + T* d_input; + output_type* d_output = nullptr; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK( + hipMemcpy(d_input, input.data(), input.size() * sizeof(input[0]), hipMemcpyHostToDevice)); + + if(copy) + { + HIP_CHECK(hipMalloc(&d_output, size * sizeof(output_type))); + } + + static constexpr std::integral_constant left_tag; + static constexpr std::integral_constant copy_tag; + + // Allocate temporary storage + std::size_t temp_storage_size{}; + void* d_temp_storage = nullptr; + + const auto launch = [&] { + return dispatch_adjacent_difference(left_tag, + copy_tag, + d_temp_storage, + temp_storage_size, + d_input, + d_output, + size, + hipcub::Sum{}, + stream, + debug_synchronous); + }; + HIP_CHECK(launch()); + HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(launch()); + } + HIP_CHECK(hipDeviceSynchronize()); + + // Run + for(auto _ : state) + { + auto start = std::chrono::high_resolution_clock::now(); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(launch()); + } + HIP_CHECK(hipStreamSynchronize(stream)); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds + = std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed_seconds.count()); + } + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + hipFree(d_input); + if(copy) + { + hipFree(d_output); + } + hipFree(d_temp_storage); +} + +} 
// namespace + +using namespace std::string_literals; + +#define CREATE_BENCHMARK(T, left, copy) \ + benchmark::RegisterBenchmark(("Subtract" + (left ? "Left"s : "Right"s) \ + + (copy ? "Copy"s : ""s) + "<" #T ">") \ + .c_str(), \ + &run_benchmark, \ + size, \ + stream) + +// clang-format off +#define CREATE_BENCHMARKS(T) \ + CREATE_BENCHMARK(T, true, false), \ + CREATE_BENCHMARK(T, true, true), \ + CREATE_BENCHMARK(T, false, false), \ + CREATE_BENCHMARK(T, false, true) +// clang-format on + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + + // HIP + const hipStream_t stream = 0; // default + hipDeviceProp_t devProp; + int device_id = 0; + HIP_CHECK(hipGetDevice(&device_id)); + HIP_CHECK(hipGetDeviceProperties(&devProp, device_id)); + std::cout << "[HIP] Device name: " << devProp.name << std::endl; + + using custom_float2 = benchmark_utils::custom_type; + using custom_double2 = benchmark_utils::custom_type; + + // Add benchmarks + const std::vector benchmarks = { + CREATE_BENCHMARKS(int), + CREATE_BENCHMARKS(std::int64_t), + + CREATE_BENCHMARKS(uint8_t), + + CREATE_BENCHMARKS(float), + CREATE_BENCHMARKS(double), + + CREATE_BENCHMARKS(custom_float2), + CREATE_BENCHMARKS(custom_double2), + }; + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + + return 0; +} diff --git a/benchmark/benchmark_device_select.cpp b/benchmark/benchmark_device_select.cpp index 3e6c6f6d..bcd9c524 100644 --- 
a/benchmark/benchmark_device_select.cpp +++ b/benchmark/benchmark_device_select.cpp @@ -368,6 +368,139 @@ void run_unique_benchmark(benchmark::State& state, hipFree(d_temp_storage); } +template +void run_unique_by_key_benchmark(benchmark::State& state, + size_t size, + const hipStream_t stream, + float discontinuity_probability) +{ + hipcub::Sum op; + + std::vector input_keys(size); + { + auto input01 = benchmark_utils::get_random_data01(size, discontinuity_probability); + auto acc = input01[0]; + + input_keys[0] = acc; + + for (size_t i = 1; i < input01.size(); i++) + { + input_keys[i] = op(acc, input01[i]); + } + } + + const auto input_values = benchmark_utils::get_random_data(size, ValueT(-1000), ValueT(1000)); + unsigned int selected_count_output = 0; + + KeyT* d_keys_input; + ValueT* d_values_input; + KeyT* d_keys_output; + ValueT* d_values_output; + unsigned int* d_selected_count_output; + + HIP_CHECK(hipMalloc(&d_keys_input, input_keys.size() * sizeof(input_keys[0]))); + HIP_CHECK(hipMalloc(&d_values_input, input_values.size() * sizeof(input_values[0]))); + HIP_CHECK(hipMalloc(&d_keys_output, input_keys.size() * sizeof(input_keys[0]))); + HIP_CHECK(hipMalloc(&d_values_output, input_values.size() * sizeof(input_values[0]))); + HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(selected_count_output))); + + HIP_CHECK( + hipMemcpy( + d_keys_input, + input_keys.data(), + input_keys.size() * sizeof(input_keys[0]), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_values_input, + input_values.data(), + input_values.size() * sizeof(input_values[0]), + hipMemcpyHostToDevice + ) + ); + + // Allocate temporary storage memory + size_t temp_storage_size_bytes; + + // Get size of d_temp_storage + HIP_CHECK( + hipcub::DeviceSelect::UniqueByKey( + nullptr, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + d_keys_output, + d_values_output, + d_selected_count_output, + input_keys.size(), + stream + ) + ); + HIP_CHECK(hipDeviceSynchronize()); + + 
// allocate temporary storage + void* d_temp_storage = nullptr; + HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for (size_t i = 0; i < 10; i++) + { + HIP_CHECK( + hipcub::DeviceSelect::UniqueByKey( + d_temp_storage, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + d_keys_output, + d_values_output, + d_selected_count_output, + input_keys.size(), + stream + ) + ); + } + HIP_CHECK(hipDeviceSynchronize()); + + const unsigned int batch_size = 10; + for (auto _ : state) + { + auto start = std::chrono::high_resolution_clock::now(); + for (size_t i = 0; i < batch_size; i++) + { + HIP_CHECK( + hipcub::DeviceSelect::UniqueByKey( + d_temp_storage, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + d_keys_output, + d_values_output, + d_selected_count_output, + input_keys.size(), + stream + ) + ); + } + HIP_CHECK(hipDeviceSynchronize()); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds = std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed_seconds.count()); + } + state.SetBytesProcessed(state.iterations() * batch_size * size * (sizeof(KeyT) + sizeof(ValueT))); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + hipFree(d_keys_input); + hipFree(d_values_input); + hipFree(d_keys_output); + hipFree(d_values_output); + hipFree(d_selected_count_output); + hipFree(d_temp_storage); +} + #define CREATE_SELECT_FLAGGED_BENCHMARK(T, F, p) \ benchmark::RegisterBenchmark( \ ("select_flagged<" #T "," #F ", "#T", unsigned int>(p = " #p")"), \ @@ -386,6 +519,12 @@ benchmark::RegisterBenchmark( \ &run_unique_benchmark, size, stream, p \ ) +#define CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, p) \ +benchmark::RegisterBenchmark( \ + ("unique_by_key<" #K ", "#V", unsigned int>(p = " #p")"), \ + &run_unique_by_key_benchmark, size, stream, p \ +) + #define BENCHMARK_FLAGGED_TYPE(type, value) \ CREATE_SELECT_FLAGGED_BENCHMARK(type, value, 
0.05f), \ CREATE_SELECT_FLAGGED_BENCHMARK(type, value, 0.25f), \ @@ -404,6 +543,12 @@ benchmark::RegisterBenchmark( \ CREATE_UNIQUE_BENCHMARK(type, 0.5f), \ CREATE_UNIQUE_BENCHMARK(type, 0.75f) +#define BENCHMARK_UNIQUE_BY_KEY_TYPE(key_type, value_type) \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(key_type, value_type, 0.05f), \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(key_type, value_type, 0.25f), \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(key_type, value_type, 0.5f), \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(key_type, value_type, 0.75f) + int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); @@ -449,7 +594,14 @@ int main(int argc, char *argv[]) BENCHMARK_UNIQUE_TYPE(double), BENCHMARK_UNIQUE_TYPE(uint8_t), BENCHMARK_UNIQUE_TYPE(int8_t), - BENCHMARK_UNIQUE_TYPE(custom_int_double) + BENCHMARK_UNIQUE_TYPE(custom_int_double), + + BENCHMARK_UNIQUE_BY_KEY_TYPE(int, int), + BENCHMARK_UNIQUE_BY_KEY_TYPE(float, double), + BENCHMARK_UNIQUE_BY_KEY_TYPE(double, custom_double2), + BENCHMARK_UNIQUE_BY_KEY_TYPE(uint8_t, uint8_t), + BENCHMARK_UNIQUE_BY_KEY_TYPE(int8_t, double), + BENCHMARK_UNIQUE_BY_KEY_TYPE(custom_int_double, custom_int_double) }; // Use manual timing diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 95908cd7..3090a777 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -46,54 +46,54 @@ if(HIP_COMPILER STREQUAL "nvcc") if(NOT DEFINED CUB_INCLUDE_DIR) file( - DOWNLOAD https://github.com/NVIDIA/cub/archive/1.15.0.zip - ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0.zip + DOWNLOAD https://github.com/NVIDIA/cub/archive/1.16.0.zip + ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0.zip STATUS cub_download_status LOG cub_download_log ) list(GET cub_download_status 0 cub_download_error_code) if(cub_download_error_code) message(FATAL_ERROR "Error: downloading " - "https://github.com/NVIDIA/cub/archive/1.15.0.zip failed " + "https://github.com/NVIDIA/cub/archive/1.16.0.zip failed " "error_code: ${cub_download_error_code} " "log: ${cub_download_log} " ) endif() 
execute_process( - COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0.zip + COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0.zip WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} RESULT_VARIABLE cub_unpack_error_code ) if(cub_unpack_error_code) - message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0.zip failed") + message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0.zip failed") endif() - set(CUB_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0/ CACHE PATH "") + set(CUB_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0/ CACHE PATH "") endif() if(NOT DEFINED THRUST_INCLUDE_DIR) file( - DOWNLOAD https://github.com/NVIDIA/thrust/archive/1.15.0.zip - ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0.zip + DOWNLOAD https://github.com/NVIDIA/thrust/archive/1.16.0.zip + ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0.zip STATUS thrust_download_status LOG thrust_download_log ) list(GET thrust_download_status 0 thrust_download_error_code) if(thrust_download_error_code) message(FATAL_ERROR "Error: downloading " - "https://github.com/NVIDIA/thrust/archive/1.15.0.zip failed " + "https://github.com/NVIDIA/thrust/archive/1.16.0.zip failed " "error_code: ${thrust_download_error_code} " "log: ${thrust_download_log} " ) endif() execute_process( - COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0.zip + COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0.zip WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} RESULT_VARIABLE thrust_unpack_error_code ) if(thrust_unpack_error_code) - message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0.zip failed") + message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0.zip failed") endif() - set(THRUST_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0/ CACHE PATH "") + set(THRUST_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0/ CACHE PATH "") endif() else() # 
rocPRIM (only for ROCm platform) diff --git a/cmake/SetupNVCC.cmake b/cmake/SetupNVCC.cmake index aae07b6c..925baa1b 100644 --- a/cmake/SetupNVCC.cmake +++ b/cmake/SetupNVCC.cmake @@ -63,7 +63,7 @@ function(hip_cuda_detect_cc out_variable) endif() if(NOT HIP_CUDA_detected_cc) - set(HIP_CUDA_detected_cc "35") + set(HIP_CUDA_detected_cc "53") set(${out_variable} ${HIP_CUDA_detected_cc} PARENT_SCOPE) else() set(${out_variable} ${HIP_CUDA_detected_cc} PARENT_SCOPE) @@ -81,6 +81,7 @@ endif() # Get CUDA enable_language("CUDA") +set(CMAKE_CUDA_STANDARD 14) # Suppressing warnings set(HIP_NVCC_FLAGS " ${HIP_NVCC_FLAGS} -Wno-deprecated-gpu-targets -Xcompiler -Wno-return-type -Wno-deprecated-declarations ") @@ -96,11 +97,14 @@ endif() set(NVGPU_TARGETS "${DEFAULT_NVGPU_TARGETS}" CACHE STRING "List of NVIDIA GPU targets (compute capabilities), for example \"35;50\"" ) -# Generate compiler flags based on targeted CUDA architectures -foreach(CUDA_ARCH ${NVGPU_TARGETS}) - list(APPEND HIP_NVCC_FLAGS "--generate-code arch=compute_${CUDA_ARCH},code=sm_${CUDA_ARCH} ") - list(APPEND HIP_NVCC_FLAGS "--generate-code arch=compute_${CUDA_ARCH},code=compute_${CUDA_ARCH} ") -endforeach() +set(CMAKE_CUDA_ARCHITECTURES ${NVGPU_TARGETS}) +# Generate compiler flags based on targeted CUDA architectures if CMake doesn't. 
(Controlled by policy CMP0104, on by default after 3.18) +if(CMAKE_VERSION VERSION_LESS "3.18") + foreach(CUDA_ARCH ${NVGPU_TARGETS}) + list(APPEND HIP_NVCC_FLAGS "--generate-code arch=compute_${CUDA_ARCH},code=sm_${CUDA_ARCH} ") + list(APPEND HIP_NVCC_FLAGS "--generate-code arch=compute_${CUDA_ARCH},code=compute_${CUDA_ARCH} ") + endforeach() +endif() execute_process( COMMAND ${HIP_HIPCONFIG_EXECUTABLE} --cpp_config diff --git a/hipcub/include/hipcub/backend/cub/device/device_adjacent_difference.hpp b/hipcub/include/hipcub/backend/cub/device/device_adjacent_difference.hpp new file mode 100644 index 00000000..8893aaac --- /dev/null +++ b/hipcub/include/hipcub/backend/cub/device/device_adjacent_difference.hpp @@ -0,0 +1,123 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_CUB_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ +#define HIPCUB_CUB_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ + +#include "../../../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceAdjacentDifference +{ + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractLeftCopy(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return hipCUDAErrorTohipError( + ::cub::DeviceAdjacentDifference::SubtractLeftCopy( + d_temp_storage, temp_storage_bytes, d_input, d_output, + num_items, difference_op, stream, debug_synchronous + ) + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractLeft(void *d_temp_storage, + std::size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return hipCUDAErrorTohipError( + ::cub::DeviceAdjacentDifference::SubtractLeft( + d_temp_storage, temp_storage_bytes, d_input, + num_items, difference_op, stream, debug_synchronous + ) + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractRightCopy(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_input, 
+ OutputIteratorT d_output, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return hipCUDAErrorTohipError( + ::cub::DeviceAdjacentDifference::SubtractRightCopy( + d_temp_storage, temp_storage_bytes, d_input, d_output, + num_items, difference_op, stream, debug_synchronous + ) + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractRight(void *d_temp_storage, + std::size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return hipCUDAErrorTohipError( + ::cub::DeviceAdjacentDifference::SubtractRight( + d_temp_storage, temp_storage_bytes, d_input, + num_items, difference_op, stream, debug_synchronous + ) + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_CUB_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ diff --git a/hipcub/include/hipcub/backend/cub/device/device_radix_sort.hpp b/hipcub/include/hipcub/backend/cub/device/device_radix_sort.hpp index 571fdeae..c3849964 100644 --- a/hipcub/include/hipcub/backend/cub/device/device_radix_sort.hpp +++ b/hipcub/include/hipcub/backend/cub/device/device_radix_sort.hpp @@ -38,7 +38,7 @@ BEGIN_HIPCUB_NAMESPACE struct DeviceRadixSort { - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairs(void * d_temp_storage, size_t& temp_storage_bytes, @@ -46,7 +46,7 @@ struct DeviceRadixSort KeyT * d_keys_out, const ValueT * d_values_in, ValueT * d_values_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -63,13 +63,13 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairs(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, DoubleBuffer& d_values, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 
0, @@ -85,7 +85,7 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairsDescending(void * d_temp_storage, size_t& temp_storage_bytes, @@ -93,7 +93,7 @@ struct DeviceRadixSort KeyT * d_keys_out, const ValueT * d_values_in, ValueT * d_values_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -111,13 +111,13 @@ struct DeviceRadixSort } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairsDescending(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, DoubleBuffer& d_values, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -133,13 +133,13 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeys(void * d_temp_storage, size_t& temp_storage_bytes, const KeyT * d_keys_in, KeyT * d_keys_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -155,12 +155,12 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeys(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -176,13 +176,13 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeysDescending(void * d_temp_storage, size_t& temp_storage_bytes, const KeyT * d_keys_in, KeyT * d_keys_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -198,12 +198,12 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeysDescending(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, - int num_items, + NumItemsT num_items, int begin_bit = 
0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, diff --git a/hipcub/include/hipcub/backend/cub/device/device_select.hpp b/hipcub/include/hipcub/backend/cub/device/device_select.hpp index f753ac02..ecacb970 100644 --- a/hipcub/include/hipcub/backend/cub/device/device_select.hpp +++ b/hipcub/include/hipcub/backend/cub/device/device_select.hpp @@ -116,6 +116,36 @@ class DeviceSelect ) ); } + + template < + typename KeyIteratorT, + typename ValueIteratorT, + typename OutputKeyIteratorT, + typename OutputValueIteratorT, + typename NumSelectedIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + hipError_t UniqueByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys_input, + ValueIteratorT d_values_input, + OutputKeyIteratorT d_keys_output, + OutputValueIteratorT d_values_output, + NumSelectedIteratorT d_num_selected_out, + int num_items, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return hipCUDAErrorTohipError( + ::cub::DeviceSelect::UniqueByKey( + d_temp_storage, temp_storage_bytes, + d_keys_input, d_values_input, + d_keys_output, d_values_output, + d_num_selected_out, num_items, + stream, debug_synchronous + ) + ); + } }; END_HIPCUB_NAMESPACE diff --git a/hipcub/include/hipcub/backend/cub/hipcub.hpp b/hipcub/include/hipcub/backend/cub/hipcub.hpp index 3504742b..b5aa24af 100644 --- a/hipcub/include/hipcub/backend/cub/hipcub.hpp +++ b/hipcub/include/hipcub/backend/cub/hipcub.hpp @@ -80,6 +80,7 @@ // Device functions must be wrapped so they return // hipError_t instead of cudaError_t +#include "device/device_adjacent_difference.hpp" #include "device/device_histogram.hpp" #include "device/device_radix_sort.hpp" #include "device/device_reduce.hpp" diff --git a/hipcub/include/hipcub/backend/rocprim/block/block_adjacent_difference.hpp b/hipcub/include/hipcub/backend/rocprim/block/block_adjacent_difference.hpp index 33bd8518..42080f07 100644 --- a/hipcub/include/hipcub/backend/rocprim/block/block_adjacent_difference.hpp +++ 
b/hipcub/include/hipcub/backend/rocprim/block/block_adjacent_difference.hpp @@ -81,57 +81,73 @@ class BlockAdjacentDifference } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_heads(head_flags, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op, T tile_predecessor_item) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_heads(head_flags, tile_predecessor_item, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_tails(tail_flags, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op, T tile_successor_item) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_tails(tail_flags, tile_successor_item, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op) { + 
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_heads_and_tails( head_flags, tail_flags, input, flag_op, temp_storage_ ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], FlagT (&tail_flags)[ITEMS_PER_THREAD], @@ -139,13 +155,16 @@ class BlockAdjacentDifference T (&input)[ITEMS_PER_THREAD], FlagOp flag_op) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_heads_and_tails( head_flags, tail_flags, tile_successor_item, input, flag_op, temp_storage_ ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], T tile_predecessor_item, @@ -153,13 +172,16 @@ class BlockAdjacentDifference T (&input)[ITEMS_PER_THREAD], FlagOp flag_op) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_heads_and_tails( head_flags, tile_predecessor_item, tail_flags, input, flag_op, temp_storage_ ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP } template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] HIPCUB_DEVICE inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], T tile_predecessor_item, @@ -168,10 +190,82 @@ class BlockAdjacentDifference T (&input)[ITEMS_PER_THREAD], FlagOp flag_op) { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") base_type::flag_heads_and_tails( head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input, flag_op, temp_storage_ ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + HIPCUB_DEVICE inline + void SubtractLeft(T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op) + { + base_type::subtract_left( + input, output, difference_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE 
inline + void SubtractLeft(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + T tile_predecessor_item) + { + base_type::subtract_left( + input, output, difference_op, tile_predecessor_item, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractLeftPartialTile(T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items) + { + base_type::subtract_left_partial( + input, output, difference_op, valid_items, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractRight(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op) + { + base_type::subtract_right( + input, output, difference_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractRight(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + T tile_successor_item) + { + base_type::subtract_right( + input, output, difference_op, tile_successor_item, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractRightPartialTile(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items) + { + base_type::subtract_right_partial( + input, output, difference_op, valid_items, temp_storage_ + ); } private: diff --git a/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp b/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp index d6df9da4..cb308a1a 100644 --- a/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp +++ b/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp @@ -105,9 +105,9 @@ class BlockRadixRank typedef unsigned short DigitCounter; // Integer type for packing DigitCounters into columns of shared memory banks - typedef typename If<(SMEM_CONFIG == hipSharedMemBankSizeEightByte), + typedef typename std::conditional<(SMEM_CONFIG == 
hipSharedMemBankSizeEightByte), unsigned long long, - unsigned int>::Type PackedCounter; + unsigned int>::type PackedCounter; enum { diff --git a/hipcub/include/hipcub/backend/rocprim/device/device_adjacent_difference.hpp b/hipcub/include/hipcub/backend/rocprim/device/device_adjacent_difference.hpp new file mode 100644 index 00000000..f46a622b --- /dev/null +++ b/hipcub/include/hipcub/backend/rocprim/device/device_adjacent_difference.hpp @@ -0,0 +1,116 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ + +#include "../../../config.hpp" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceAdjacentDifference +{ + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractLeftCopy(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return ::rocprim::adjacent_difference( + d_temp_storage, temp_storage_bytes, d_input, d_output, + num_items, difference_op, stream, debug_synchronous + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractLeft(void *d_temp_storage, + std::size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return ::rocprim::adjacent_difference_inplace( + d_temp_storage, temp_storage_bytes, d_input, + num_items, difference_op, stream, debug_synchronous + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractRightCopy(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + std::size_t num_items, + 
DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return ::rocprim::adjacent_difference_right( + d_temp_storage, temp_storage_bytes, d_input, d_output, + num_items, difference_op, stream, debug_synchronous + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION hipError_t + SubtractRight(void *d_temp_storage, + std::size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + std::size_t num_items, + DifferenceOpT difference_op = {}, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return ::rocprim::adjacent_difference_right_inplace( + d_temp_storage, temp_storage_bytes, d_input, + num_items, difference_op, stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ diff --git a/hipcub/include/hipcub/backend/rocprim/device/device_radix_sort.hpp b/hipcub/include/hipcub/backend/rocprim/device/device_radix_sort.hpp index 3cde8a0a..1148ef69 100644 --- a/hipcub/include/hipcub/backend/rocprim/device/device_radix_sort.hpp +++ b/hipcub/include/hipcub/backend/rocprim/device/device_radix_sort.hpp @@ -40,7 +40,7 @@ BEGIN_HIPCUB_NAMESPACE struct DeviceRadixSort { - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairs(void * d_temp_storage, size_t& temp_storage_bytes, @@ -48,7 +48,7 @@ struct DeviceRadixSort KeyT * d_keys_out, const ValueT * d_values_in, ValueT * d_values_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -62,13 +62,13 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairs(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, DoubleBuffer& d_values, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -87,7 +87,7 @@ struct DeviceRadixSort return error; } - template + template HIPCUB_RUNTIME_FUNCTION 
static hipError_t SortPairsDescending(void * d_temp_storage, size_t& temp_storage_bytes, @@ -95,7 +95,7 @@ struct DeviceRadixSort KeyT * d_keys_out, const ValueT * d_values_in, ValueT * d_values_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -109,13 +109,13 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortPairsDescending(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, DoubleBuffer& d_values, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -134,13 +134,13 @@ struct DeviceRadixSort return error; } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeys(void * d_temp_storage, size_t& temp_storage_bytes, const KeyT * d_keys_in, KeyT * d_keys_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -154,12 +154,12 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeys(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -176,13 +176,13 @@ struct DeviceRadixSort return error; } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeysDescending(void * d_temp_storage, size_t& temp_storage_bytes, const KeyT * d_keys_in, KeyT * d_keys_out, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, @@ -196,12 +196,12 @@ struct DeviceRadixSort ); } - template + template HIPCUB_RUNTIME_FUNCTION static hipError_t SortKeysDescending(void * d_temp_storage, size_t& temp_storage_bytes, DoubleBuffer& d_keys, - int num_items, + NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, hipStream_t stream = 0, diff --git 
a/hipcub/include/hipcub/backend/rocprim/device/device_select.hpp b/hipcub/include/hipcub/backend/rocprim/device/device_select.hpp index 0f790608..a7847644 100644 --- a/hipcub/include/hipcub/backend/rocprim/device/device_select.hpp +++ b/hipcub/include/hipcub/backend/rocprim/device/device_select.hpp @@ -110,6 +110,34 @@ class DeviceSelect stream, debug_synchronous ); } + + template < + typename KeyIteratorT, + typename ValueIteratorT, + typename OutputKeyIteratorT, + typename OutputValueIteratorT, + typename NumSelectedIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + hipError_t UniqueByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys_input, + ValueIteratorT d_values_input, + OutputKeyIteratorT d_keys_output, + OutputValueIteratorT d_values_output, + NumSelectedIteratorT d_num_selected_out, + int num_items, + hipStream_t stream = 0, + bool debug_synchronous = false) + { + return ::rocprim::unique_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_input, d_values_input, + d_keys_output, d_values_output, + d_num_selected_out, num_items, hipcub::Equality(), + stream, debug_synchronous + ); + } }; END_HIPCUB_NAMESPACE diff --git a/hipcub/include/hipcub/backend/rocprim/hipcub.hpp b/hipcub/include/hipcub/backend/rocprim/hipcub.hpp index bce61b8a..eacf10c5 100644 --- a/hipcub/include/hipcub/backend/rocprim/hipcub.hpp +++ b/hipcub/include/hipcub/backend/rocprim/hipcub.hpp @@ -76,6 +76,7 @@ #include "block/block_store.hpp" // Device +#include "device/device_adjacent_difference.hpp" #include "device/device_histogram.hpp" #include "device/device_radix_sort.hpp" #include "device/device_reduce.hpp" diff --git a/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_input_iterator.hpp b/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_input_iterator.hpp index 4a264d3d..156a4191 100644 --- a/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_input_iterator.hpp +++ 
b/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_input_iterator.hpp @@ -69,7 +69,7 @@ class CacheModifiedInputIterator __host__ __device__ __forceinline__ CacheModifiedInputIterator( ValueType* ptr) ///< Native pointer to wrap : - ptr(const_cast::Type *>(ptr)) + ptr(const_cast::type *>(ptr)) {} /// Postfix increment diff --git a/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_output_iterator.hpp b/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_output_iterator.hpp index c4f5a82e..e330d25f 100644 --- a/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_output_iterator.hpp +++ b/hipcub/include/hipcub/backend/rocprim/iterator/cache_modified_output_iterator.hpp @@ -91,7 +91,7 @@ class CacheModifiedOutputIterator __host__ __device__ __forceinline__ CacheModifiedOutputIterator( QualifiedValueType* ptr) ///< Native pointer to wrap : - ptr(const_cast::Type *>(ptr)) + ptr(const_cast::type *>(ptr)) {} /// Postfix increment diff --git a/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp b/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp index 32069476..1fda5711 100644 --- a/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp +++ b/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp @@ -82,6 +82,26 @@ struct Sum } }; +struct Difference +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a - b; + } +}; + +struct Division +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a / b; + } +}; + struct Max { template @@ -132,6 +152,127 @@ struct ArgMin } }; +template +struct CastOp +{ + template + HIPCUB_HOST_DEVICE inline + B operator()(const A &a) const + { + return (B)a; + } +}; + +template +class SwizzleScanOp +{ +private: + ScanOp scan_op; + +public: + HIPCUB_HOST_DEVICE inline + SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) + { + } + + template + 
HIPCUB_HOST_DEVICE inline + T operator()(const T &a, const T &b) + { + T _a(a); + T _b(b); + + return scan_op(_b, _a); + } +}; + +template +struct ReduceBySegmentOp +{ + ReductionOpT op; + + HIPCUB_HOST_DEVICE inline + ReduceBySegmentOp() + { + } + + HIPCUB_HOST_DEVICE inline + ReduceBySegmentOp(ReductionOpT op) : op(op) + { + } + + template + HIPCUB_HOST_DEVICE inline + KeyValuePairT operator()( + const KeyValuePairT &first, + const KeyValuePairT &second) + { + KeyValuePairT retval; + retval.key = first.key + second.key; + retval.value = (second.key) ? + second.value : + op(first.value, second.value); + return retval; + } +}; + +template +struct ReduceByKeyOp +{ + ReductionOpT op; + + HIPCUB_HOST_DEVICE inline + ReduceByKeyOp() + { + } + + HIPCUB_HOST_DEVICE inline + ReduceByKeyOp(ReductionOpT op) : op(op) + { + } + + template + HIPCUB_HOST_DEVICE inline + KeyValuePairT operator()( + const KeyValuePairT &first, + const KeyValuePairT &second) + { + KeyValuePairT retval = second; + + if (first.key == second.key) + { + retval.value = op(first.value, retval.value); + } + return retval; + } +}; + +template +struct BinaryFlip +{ + BinaryOpT binary_op; + + HIPCUB_HOST_DEVICE + explicit BinaryFlip(BinaryOpT binary_op) : binary_op(binary_op) + { + } + + template + HIPCUB_DEVICE auto + operator()(T &&t, U &&u) -> decltype(binary_op(std::forward(u), + std::forward(t))) + { + return binary_op(std::forward(u), std::forward(t)); + } +}; + +template +HIPCUB_HOST_DEVICE +BinaryFlip MakeBinaryFlip(BinaryOpT binary_op) +{ + return BinaryFlip(binary_op); +} + namespace detail { diff --git a/hipcub/include/hipcub/backend/rocprim/util_type.hpp b/hipcub/include/hipcub/backend/rocprim/util_type.hpp index bb152543..48937272 100644 --- a/hipcub/include/hipcub/backend/rocprim/util_type.hpp +++ b/hipcub/include/hipcub/backend/rocprim/util_type.hpp @@ -49,26 +49,26 @@ using NullType = ::rocprim::empty_type; #endif -template -struct If +template struct +[[deprecated("[Since 1.16] If is 
deprecated use std::conditional instead.")]] If { using Type = typename std::conditional::type; }; -template -struct IsPointer +template struct +[[deprecated("[Since 1.16] IsPointer is deprecated use std::is_pointer instead.")]] IsPointer { static constexpr bool VALUE = std::is_pointer::value; }; -template -struct IsVolatile +template struct +[[deprecated("[Since 1.16] IsVolatile is deprecated use std::is_volatile instead.")]] IsVolatile { static constexpr bool VALUE = std::is_volatile::value; }; -template -struct RemoveQualifiers +template struct +[[deprecated("[Since 1.16] RemoveQualifiers is deprecated use std::remove_cv instead.")]] RemoveQualifiers { using Type = typename std::remove_cv::type; }; @@ -189,7 +189,7 @@ using is_integral_or_enum = } template -__host__ __device__ __forceinline__ constexpr NumeratorT +HIPCUB_HOST_DEVICE __forceinline__ constexpr NumeratorT DivideAndRoundUp(NumeratorT n, DenominatorT d) { static_assert(hipcub::detail::is_integral_or_enum::value && @@ -284,28 +284,28 @@ struct UnitWord }; /// Biggest shuffle word that T is a whole multiple of and is not larger than the alignment of T - typedef typename If::IS_MULTIPLE, + typedef typename std::conditional::IS_MULTIPLE, unsigned int, - typename If::IS_MULTIPLE, + typename std::conditional::IS_MULTIPLE, unsigned short, - unsigned char>::Type>::Type ShuffleWord; + unsigned char>::type>::type ShuffleWord; /// Biggest volatile word that T is a whole multiple of and is not larger than the alignment of T - typedef typename If::IS_MULTIPLE, + typedef typename std::conditional::IS_MULTIPLE, unsigned long long, - ShuffleWord>::Type VolatileWord; + ShuffleWord>::type VolatileWord; /// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T - typedef typename If::IS_MULTIPLE, + typedef typename std::conditional::IS_MULTIPLE, ulonglong2, - VolatileWord>::Type DeviceWord; + VolatileWord>::type DeviceWord; /// Biggest texture reference word that T is a whole 
multiple of and is not larger than the alignment of T - typedef typename If::IS_MULTIPLE, + typedef typename std::conditional::IS_MULTIPLE, uint4, - typename If::IS_MULTIPLE, + typename std::conditional::IS_MULTIPLE, uint2, - ShuffleWord>::Type>::Type TextureWord; + ShuffleWord>::type>::type TextureWord; }; @@ -364,16 +364,15 @@ struct Uninitialized /// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T typedef typename UnitWord::DeviceWord DeviceWord; - enum - { - WORDS = sizeof(T) / sizeof(DeviceWord) - }; + static constexpr std::size_t DATA_SIZE = sizeof(T); + static constexpr std::size_t WORD_SIZE = sizeof(DeviceWord); + static constexpr std::size_t WORDS = DATA_SIZE / WORD_SIZE; /// Backing storage DeviceWord storage[WORDS]; /// Alias - __host__ __device__ __forceinline__ T& Alias() + HIPCUB_HOST_DEVICE __forceinline__ T& Alias() { return reinterpret_cast(*this); } @@ -440,26 +439,30 @@ struct BaseTraits }; - static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { return key; } - static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { return key; } - static __host__ __device__ __forceinline__ T Max() + static HIPCUB_HOST_DEVICE __forceinline__ T Max() { - UnsignedBits retval = MAX_KEY; - return reinterpret_cast(retval); + UnsignedBits retval_bits = MAX_KEY; + T retval; + memcpy(&retval, &retval_bits, sizeof(T)); + return retval; } - static __host__ __device__ __forceinline__ T Lowest() + static HIPCUB_HOST_DEVICE __forceinline__ T Lowest() { - UnsignedBits retval = LOWEST_KEY; - return reinterpret_cast(retval); + UnsignedBits retval_bits = LOWEST_KEY; + T retval; + memcpy(&retval, &retval_bits, sizeof(T)); + return retval; } }; @@ -483,23 +486,23 @@ struct BaseTraits NULL_TYPE = false, }; - static 
__device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { return key ^ HIGH_BIT; }; - static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { return key ^ HIGH_BIT; }; - static __host__ __device__ __forceinline__ T Max() + static HIPCUB_HOST_DEVICE __forceinline__ T Max() { UnsignedBits retval = MAX_KEY; return reinterpret_cast(retval); } - static __host__ __device__ __forceinline__ T Lowest() + static HIPCUB_HOST_DEVICE __forceinline__ T Lowest() { UnsignedBits retval = LOWEST_KEY; return reinterpret_cast(retval); @@ -512,11 +515,11 @@ struct FpLimits; template <> struct FpLimits { - static __host__ __device__ __forceinline__ float Max() { + static HIPCUB_HOST_DEVICE __forceinline__ float Max() { return std::numeric_limits::max(); } - static __host__ __device__ __forceinline__ float Lowest() { + static HIPCUB_HOST_DEVICE __forceinline__ float Lowest() { return std::numeric_limits::max() * float(-1); } }; @@ -524,11 +527,11 @@ struct FpLimits template <> struct FpLimits { - static __host__ __device__ __forceinline__ double Max() { + static HIPCUB_HOST_DEVICE __forceinline__ double Max() { return std::numeric_limits::max(); } - static __host__ __device__ __forceinline__ double Lowest() { + static HIPCUB_HOST_DEVICE __forceinline__ double Lowest() { return std::numeric_limits::max() * double(-1); } }; @@ -536,12 +539,12 @@ struct FpLimits template <> struct FpLimits<__half> { - static __host__ __device__ __forceinline__ __half Max() { + static HIPCUB_HOST_DEVICE __forceinline__ __half Max() { unsigned short max_word = 0x7BFF; return reinterpret_cast<__half&>(max_word); } - static __host__ __device__ __forceinline__ __half Lowest() { + static HIPCUB_HOST_DEVICE __forceinline__ __half Lowest() { unsigned short lowest_word = 0xFBFF; return 
reinterpret_cast<__half&>(lowest_word); } @@ -550,12 +553,12 @@ struct FpLimits<__half> template <> struct FpLimits { - static __host__ __device__ __forceinline__ hip_bfloat16 Max() { + static HIPCUB_HOST_DEVICE __forceinline__ hip_bfloat16 Max() { unsigned short max_word = 0x7F7F; return reinterpret_cast(max_word); } - static __host__ __device__ __forceinline__ hip_bfloat16 Lowest() { + static HIPCUB_HOST_DEVICE __forceinline__ hip_bfloat16 Lowest() { unsigned short lowest_word = 0xFF7F; return reinterpret_cast(lowest_word); } @@ -580,23 +583,23 @@ struct BaseTraits NULL_TYPE = false, }; - static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT; return key ^ mask; }; - static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { UnsignedBits mask = (key & HIGH_BIT) ? 
HIGH_BIT : UnsignedBits(-1); return key ^ mask; }; - static __host__ __device__ __forceinline__ T Max() { + static HIPCUB_HOST_DEVICE __forceinline__ T Max() { return FpLimits::Max(); } - static __host__ __device__ __forceinline__ T Lowest() { + static HIPCUB_HOST_DEVICE __forceinline__ T Lowest() { return FpLimits::Lowest(); } }; @@ -633,7 +636,7 @@ template <> struct NumericTraits : BaseTraits -struct Traits : NumericTraits::Type> {}; +struct Traits : NumericTraits::type> {}; #endif // DOXYGEN_SHOULD_SKIP_THIS diff --git a/hipcub/include/hipcub/backend/rocprim/warp/warp_exchange.hpp b/hipcub/include/hipcub/backend/rocprim/warp/warp_exchange.hpp index 3efaef9c..8ca2f85c 100644 --- a/hipcub/include/hipcub/backend/rocprim/warp/warp_exchange.hpp +++ b/hipcub/include/hipcub/backend/rocprim/warp/warp_exchange.hpp @@ -33,6 +33,8 @@ #include "../../../config.hpp" #include "../util_type.hpp" +#include + BEGIN_HIPCUB_NAMESPACE template < @@ -43,40 +45,20 @@ template < > class WarpExchange { - static_assert(PowerOfTwo::VALUE, - "LOGICAL_WARP_THREADS must be a power of two"); - - constexpr static int SMEM_BANKS = ::rocprim::detail::get_lds_banks_no(); - - constexpr static bool HAS_BANK_CONFLICTS = - ITEMS_PER_THREAD > 4 && PowerOfTwo::VALUE; - - constexpr static int BANK_CONFLICTS_PADDING = - HAS_BANK_CONFLICTS ? 
(ITEMS_PER_THREAD / SMEM_BANKS) : 0; + using base_type = typename rocprim::warp_exchange; - constexpr static int ITEMS_PER_TILE = - ITEMS_PER_THREAD * LOGICAL_WARP_THREADS + BANK_CONFLICTS_PADDING; - - constexpr static bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == - HIPCUB_DEVICE_WARP_THREADS; +public: + using TempStorage = typename base_type::storage_type; - union _TempStorage - { - InputT items_shared[ITEMS_PER_TILE]; - }; +private: + TempStorage &temp_storage; - _TempStorage &temp_storage; - unsigned lane_id; - public: - struct TempStorage : Uninitialized<_TempStorage> {}; - WarpExchange() = delete; explicit HIPCUB_DEVICE __forceinline__ WarpExchange(TempStorage &temp_storage) : - temp_storage(temp_storage.Alias()), - lane_id(IS_ARCH_WARP ? LaneId() : LaneId() % LOGICAL_WARP_THREADS) + temp_storage(temp_storage) { } @@ -86,20 +68,8 @@ class WarpExchange const InputT (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD]) { - for (int item = 0; item < ITEMS_PER_THREAD; ++item) - { - const int idx = ITEMS_PER_THREAD * lane_id + item; - temp_storage.items_shared[idx] = input_items[item]; - } - - // member mask is unused in rocPRIM - WARP_SYNC(0); - - for (int item = 0; item < ITEMS_PER_THREAD; ++item) - { - const int idx = LOGICAL_WARP_THREADS * item + lane_id; - output_items[item] = temp_storage.items_shared[idx]; - } + base_type rocprim_warp_exchange; + rocprim_warp_exchange.blocked_to_striped(input_items, output_items, temp_storage); } template @@ -108,20 +78,8 @@ class WarpExchange const InputT (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD]) { - for (int item = 0; item < ITEMS_PER_THREAD; ++item) - { - const int idx = LOGICAL_WARP_THREADS * item + lane_id; - temp_storage.items_shared[idx] = input_items[item]; - } - - // member mask is unused in rocPRIM - WARP_SYNC(0); - - for (int item = 0; item < ITEMS_PER_THREAD; ++item) - { - const int idx = ITEMS_PER_THREAD * lane_id + item; - output_items[item] = 
temp_storage.items_shared[idx]; - } + base_type rocprim_warp_exchange; + rocprim_warp_exchange.striped_to_blocked(input_items, output_items, temp_storage); } template @@ -141,21 +99,8 @@ class WarpExchange OutputT (&output_items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD]) { - ROCPRIM_UNROLL - for (int item = 0; item < ITEMS_PER_THREAD; ++item) - { - temp_storage.items_shared[ranks[item]] = input_items[item]; - } - - // member mask is unused in rocPRIM - WARP_SYNC(0); - - ROCPRIM_UNROLL - for (int item = 0; item < ITEMS_PER_THREAD; item++) - { - int item_offset = (item * LOGICAL_WARP_THREADS) + lane_id; - output_items[item] = temp_storage.items_shared[item_offset]; - } + base_type rocprim_warp_exchange; + rocprim_warp_exchange.scatter_to_striped(input_items, output_items, ranks, temp_storage); } }; diff --git a/hipcub/include/hipcub/config.hpp b/hipcub/include/hipcub/config.hpp index 16b03023..80a24cc0 100644 --- a/hipcub/include/hipcub/config.hpp +++ b/hipcub/include/hipcub/config.hpp @@ -67,6 +67,21 @@ #define HIPCUB_HOST_DEVICE __host__ __device__ #define HIPCUB_SHARED_MEMORY __shared__ +// Helper macros to disable warnings in clang +#ifdef __clang__ +#define HIPCUB_PRAGMA_TO_STR(x) _Pragma(#x) +#define HIPCUB_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") +#define HIPCUB_CLANG_SUPPRESS_WARNING(w) HIPCUB_PRAGMA_TO_STR(clang diagnostic ignored w) +#define HIPCUB_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") +#define HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ + HIPCUB_CLANG_SUPPRESS_WARNING_PUSH HIPCUB_CLANG_SUPPRESS_WARNING(w) +#else // __clang__ +#define HIPCUB_CLANG_SUPPRESS_WARNING_PUSH +#define HIPCUB_CLANG_SUPPRESS_WARNING(w) +#define HIPCUB_CLANG_SUPPRESS_WARNING_POP +#define HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // __clang__ + BEGIN_HIPCUB_NAMESPACE /// hipCUB error reporting macro (prints error messages to stderr) diff --git a/hipcub/include/hipcub/device/device_adjacent_difference.hpp 
b/hipcub/include/hipcub/device/device_adjacent_difference.hpp new file mode 100644 index 00000000..9b6740df --- /dev/null +++ b/hipcub/include/hipcub/device/device_adjacent_difference.hpp @@ -0,0 +1,38 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ +#define HIPCUB_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ + +#ifdef __HIP_PLATFORM_AMD__ + #include "../backend/rocprim/device/device_adjacent_difference.hpp" +#elif defined(__HIP_PLATFORM_NVIDIA__) + #include "../backend/cub/device/device_adjacent_difference.hpp" +#endif + +#endif // HIPCUB_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ diff --git a/test/extra/CMakeLists.txt b/test/extra/CMakeLists.txt index ab81765e..68d7b515 100644 --- a/test/extra/CMakeLists.txt +++ b/test/extra/CMakeLists.txt @@ -56,55 +56,54 @@ include(VerifyCompiler) include(DownloadProject) if(HIP_COMPILER STREQUAL "nvcc") file( - DOWNLOAD https://github.com/NVlabs/cub/archive/1.15.0.zip - ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0.zip + DOWNLOAD https://github.com/NVlabs/cub/archive/1.16.0.zip + ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0.zip STATUS cub_download_status LOG cub_download_log ) list(GET cub_download_status 0 cub_download_error_code) if(cub_download_error_code) message(FATAL_ERROR "Error: downloading " - "https://github.com/NVlabs/cub/archive/1.15.0.zip failed " + "https://github.com/NVlabs/cub/archive/1.16.0.zip failed " "error_code: ${cub_download_error_code} " "log: ${cub_download_log} " ) endif() execute_process( - COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0.zip + COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0.zip WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} RESULT_VARIABLE cub_unpack_error_code ) if(cub_unpack_error_code) - message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0.zip failed") + message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0.zip failed") endif() - set(CUB_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub-1.15.0/ CACHE PATH "" FORCE) + set(CUB_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub-1.16.0/ CACHE PATH "" FORCE) message(STATUS 
"CUB_INCLUDE_DIR: ${CUB_INCLUDE_DIR}") - if(NOT DEFINED THRUST_INCLUDE_DIR) - file( - DOWNLOAD https://github.com/NVIDIA/thrust/archive/1.15.0.zip - ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0.zip - STATUS thrust_download_status LOG thrust_download_log - ) - list(GET thrust_download_status 0 thrust_download_error_code) - if(thrust_download_error_code) - message(FATAL_ERROR "Error: downloading " - "https://github.com/NVIDIA/thrust/archive/1.15.0.zip failed " - "error_code: ${thrust_download_error_code} " - "log: ${thrust_download_log} " - ) - endif() - - execute_process( - COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0.zip - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - RESULT_VARIABLE thrust_unpack_error_code + file( + DOWNLOAD https://github.com/NVIDIA/thrust/archive/1.16.0.zip + ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0.zip + STATUS thrust_download_status LOG thrust_download_log + ) + list(GET thrust_download_status 0 thrust_download_error_code) + if(thrust_download_error_code) + message(FATAL_ERROR "Error: downloading " + "https://github.com/NVIDIA/thrust/archive/1.16.0.zip failed " + "error_code: ${thrust_download_error_code} " + "log: ${thrust_download_log} " ) - if(thrust_unpack_error_code) - message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0.zip failed") - endif() - set(THRUST_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.15.0/ CACHE PATH "") endif() + + execute_process( + COMMAND ${CMAKE_COMMAND} -E tar xzf ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0.zip + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + RESULT_VARIABLE thrust_unpack_error_code + ) + if(thrust_unpack_error_code) + message(FATAL_ERROR "Error: unpacking ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0.zip failed") + endif() + set(THRUST_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/thrust-1.16.0/ CACHE PATH "" FORCE) + message(STATUS "THRUST_INCLUDE_DIR: ${THRUST_INCLUDE_DIR}") endif() # Download rocPRIM (only for ROCm platform) diff --git 
a/test/hipcub/CMakeLists.txt b/test/hipcub/CMakeLists.txt index 2170db86..7837d144 100644 --- a/test/hipcub/CMakeLists.txt +++ b/test/hipcub/CMakeLists.txt @@ -86,6 +86,7 @@ add_hipcub_test("hipcub.BlockReduce" test_hipcub_block_reduce.cpp) add_hipcub_test("hipcub.BlockRunLengthDecode" test_hipcub_block_run_length_decode.cpp) add_hipcub_test("hipcub.BlockScan" test_hipcub_block_scan.cpp) add_hipcub_test("hipcub.BlockShuffle" test_hipcub_block_shuffle.cpp) +add_hipcub_test("hipcub.DeviceAdjacentDifference" test_hipcub_device_adjacent_difference.cpp) add_hipcub_test("hipcub.DeviceHistogram" test_hipcub_device_histogram.cpp) add_hipcub_test("hipcub.DeviceMergeSort" test_hipcub_device_merge_sort.cpp) add_hipcub_test("hipcub.DeviceRadixSort" test_hipcub_device_radix_sort.cpp) diff --git a/test/hipcub/half.hpp b/test/hipcub/half.hpp index adabeccb..8f2cba45 100644 --- a/test/hipcub/half.hpp +++ b/test/hipcub/half.hpp @@ -174,7 +174,7 @@ struct half_t { if (mantissa) { - f = 0x7fffffff; // not a number + f = 0x7fffffff | (sign << 31); // not a number } else { diff --git a/test/hipcub/test_hipcub_block_adjacent_difference.cpp b/test/hipcub/test_hipcub_block_adjacent_difference.cpp index 11e16672..166f054b 100644 --- a/test/hipcub/test_hipcub_block_adjacent_difference.cpp +++ b/test/hipcub/test_hipcub_block_adjacent_difference.cpp @@ -139,6 +139,8 @@ void flag_heads_kernel(Type* device_input, long long* device_heads) hipcub::BlockAdjacentDifference bAdjacentDiff; FlagType head_flags[ItemsPerThread]; + + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") if(hipBlockIdx_x % 2 == 1) { const Type tile_predecessor_item = device_input[block_offset - 1]; @@ -148,6 +150,7 @@ void flag_heads_kernel(Type* device_input, long long* device_heads) { bAdjacentDiff.FlagHeads(head_flags, input, FlagOpType()); } + HIPCUB_CLANG_SUPPRESS_WARNING_POP hipcub::StoreDirectBlocked(lid, device_heads + block_offset, head_flags); } @@ -173,6 +176,8 @@ void flag_tails_kernel(Type* device_input, 
long long* device_tails) hipcub::BlockAdjacentDifference bAdjacentDiff; FlagType tail_flags[ItemsPerThread]; + + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") if(hipBlockIdx_x % 2 == 0) { const Type tile_successor_item = device_input[block_offset + items_per_block]; @@ -182,6 +187,7 @@ void flag_tails_kernel(Type* device_input, long long* device_tails) { bAdjacentDiff.FlagTails(tail_flags, input, FlagOpType()); } + HIPCUB_CLANG_SUPPRESS_WARNING_POP hipcub::StoreDirectBlocked(lid, device_tails + block_offset, tail_flags); } @@ -208,6 +214,8 @@ void flag_heads_and_tails_kernel(Type* device_input, long long* device_heads, lo FlagType head_flags[ItemsPerThread]; FlagType tail_flags[ItemsPerThread]; + + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") if(hipBlockIdx_x % 4 == 0) { const Type tile_successor_item = device_input[block_offset + items_per_block]; @@ -228,11 +236,210 @@ void flag_heads_and_tails_kernel(Type* device_input, long long* device_heads, lo { bAdjacentDiff.FlagHeadsAndTails(head_flags, tail_flags, input, FlagOpType()); } + HIPCUB_CLANG_SUPPRESS_WARNING_POP hipcub::StoreDirectBlocked(lid, device_heads + block_offset, head_flags); hipcub::StoreDirectBlocked(lid, device_tails + block_offset, tail_flags); } +template< + class T, + class Output, + class BinaryFunction, + unsigned int BlockSize, + unsigned int ItemsPerThread +> +struct params_subtract +{ + using type = T; + using output = Output; + using binary_function = BinaryFunction; + static constexpr unsigned int block_size = BlockSize; + static constexpr unsigned int items_per_thread = ItemsPerThread; +}; + +template +class HipcubBlockAdjacentDifferenceSubtract : public ::testing::Test { +public: + using params_subtract = ParamsSubtract; +}; + +struct custom_op1 +{ + template + HIPCUB_HOST_DEVICE + T operator()(const T& a, const T& b) const + { + return a - b; + } +}; + +struct custom_op2 +{ + template + HIPCUB_HOST_DEVICE + T operator()(const T& a, const T& b) const + { + return (b 
+ b) - a; + } +}; + +typedef ::testing::Types< + params_subtract, + params_subtract, + params_subtract, + params_subtract, + + params_subtract, + params_subtract, + params_subtract, + params_subtract, + + params_subtract, + params_subtract, + params_subtract, + params_subtract, + + params_subtract, + params_subtract, + params_subtract, + params_subtract +> ParamsSubtract; + +TYPED_TEST_SUITE(HipcubBlockAdjacentDifferenceSubtract, ParamsSubtract); + +template< + typename T, + typename Output, + typename StorageType, + typename BinaryFunction, + unsigned int BlockSize, + unsigned int ItemsPerThread +> +__global__ +__launch_bounds__(BlockSize) +void subtract_left_kernel(const T* input, StorageType* output) +{ + const unsigned int lid = threadIdx.x; + const unsigned int items_per_block = BlockSize * ItemsPerThread; + const unsigned int block_offset = blockIdx.x * items_per_block; + + T thread_items[ItemsPerThread]; + hipcub::LoadDirectBlocked(lid, input + block_offset, thread_items); + + hipcub::BlockAdjacentDifference adjacent_difference; + + Output thread_output[ItemsPerThread]; + + if (blockIdx.x % 2 == 1) + { + const T tile_predecessor_item = input[block_offset - 1]; + adjacent_difference.SubtractLeft(thread_items, thread_output, BinaryFunction{}, tile_predecessor_item); + } + else + { + adjacent_difference.SubtractLeft(thread_items, thread_output, BinaryFunction{}); + } + + hipcub::StoreDirectBlocked(lid, output + block_offset, thread_output); +} + +template< + typename T, + typename Output, + typename StorageType, + typename BinaryFunction, + unsigned int BlockSize, + unsigned int ItemsPerThread +> +__global__ +__launch_bounds__(BlockSize) +void subtract_left_partial_tile_kernel(const T* input, int* tile_sizes, StorageType* output) +{ + const unsigned int lid = threadIdx.x; + const unsigned int items_per_block = BlockSize * ItemsPerThread; + const unsigned int block_offset = blockIdx.x * items_per_block; + + T thread_items[ItemsPerThread]; + 
hipcub::LoadDirectBlocked(lid, input + block_offset, thread_items); + + hipcub::BlockAdjacentDifference adjacent_difference; + + Output thread_output[ItemsPerThread]; + + int tile_size = tile_sizes[blockIdx.x]; + + adjacent_difference.SubtractLeftPartialTile(thread_items, thread_output, BinaryFunction{}, tile_size); + + hipcub::StoreDirectBlocked(lid, output + block_offset, thread_output); +} + +template< + typename T, + typename Output, + typename StorageType, + typename BinaryFunction, + unsigned int BlockSize, + unsigned int ItemsPerThread +> +__global__ +__launch_bounds__(BlockSize) +void subtract_right_kernel(const T* input, StorageType* output) +{ + const unsigned int lid = threadIdx.x; + const unsigned int items_per_block = BlockSize * ItemsPerThread; + const unsigned int block_offset = blockIdx.x * items_per_block; + + T thread_items[ItemsPerThread]; + hipcub::LoadDirectBlocked(lid, input + block_offset, thread_items); + + hipcub::BlockAdjacentDifference adjacent_difference; + + Output thread_output[ItemsPerThread]; + + if (blockIdx.x % 2 == 0) + { + const T tile_successor_item = input[block_offset + items_per_block]; + adjacent_difference.SubtractRight(thread_items, thread_output, BinaryFunction{}, tile_successor_item); + } + else + { + adjacent_difference.SubtractRight(thread_items, thread_output, BinaryFunction{}); + } + + hipcub::StoreDirectBlocked(lid, output + block_offset, thread_output); +} + +template< + typename T, + typename Output, + typename StorageType, + typename BinaryFunction, + unsigned int BlockSize, + unsigned int ItemsPerThread +> +__global__ +__launch_bounds__(BlockSize) +void subtract_right_partial_tile_kernel(const T* input, int* tile_sizes, StorageType* output) +{ + const unsigned int lid = threadIdx.x; + const unsigned int items_per_block = BlockSize * ItemsPerThread; + const unsigned int block_offset = blockIdx.x * items_per_block; + + T thread_items[ItemsPerThread]; + hipcub::LoadDirectBlocked(lid, input + block_offset, 
thread_items); + + hipcub::BlockAdjacentDifference adjacent_difference; + + Output thread_output[ItemsPerThread]; + + int tile_size = tile_sizes[blockIdx.x]; + + adjacent_difference.SubtractRightPartialTile(thread_items, thread_output, BinaryFunction{}, tile_size); + + hipcub::StoreDirectBlocked(lid, output + block_offset, thread_output); +} + TYPED_TEST(HipcubBlockAdjacentDifference, FlagHeads) { using type = typename TestFixture::params::type; @@ -562,3 +769,421 @@ TYPED_TEST(HipcubBlockAdjacentDifference, FlagHeadsAndTails) } } + +TYPED_TEST(HipcubBlockAdjacentDifferenceSubtract, SubtractLeft) +{ + using type = typename TestFixture::params_subtract::type; + using binary_function = typename TestFixture::params_subtract::binary_function; + + using output_type = typename TestFixture::params_subtract::output; + + using stored_type = std::conditional_t::value, int, output_type>; + + constexpr size_t block_size = TestFixture::params_subtract::block_size; + constexpr size_t items_per_thread = TestFixture::params_subtract::items_per_thread; + static constexpr int items_per_block = block_size * items_per_thread; + static constexpr int size = items_per_block * 20; + static constexpr int grid_size = size / items_per_block; + + // Given block size not supported + if(block_size > test_utils::get_max_block_size()) + { + return; + } + + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + const unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + + // Generate data + const std::vector input = test_utils::get_random_data(size, 0, 10, seed_value); + std::vector output(size); + + // Calculate expected results on host + std::vector expected(size); + binary_function op; + + for(size_t block_index = 0; block_index < grid_size; ++block_index) + { + for(unsigned int item = 0; item < items_per_block; ++item) + { + const size_t i = block_index * items_per_block + item; + if(item == 0) + { + expected[i] + = static_cast(block_index % 2 == 1 ? op(input[i], input[i - 1]) : input[i]); + } + else + { + expected[i] = static_cast(op(input[i], input[i - 1])); + } + } + } + + // Preparing Device + type* d_input; + stored_type* d_output; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK(hipMalloc(&d_output, output.size() * sizeof(output[0]))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(input[0]), + hipMemcpyHostToDevice + ) + ); + + // Running kernel + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + subtract_left_kernel + ), + dim3(grid_size), dim3(block_size), 0, 0, + d_input, d_output + ); + HIP_CHECK(hipGetLastError()); + + // Reading results + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + output.size() * sizeof(output[0]), + hipMemcpyDeviceToHost + ) + ); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( + output, expected, test_utils::precision_threshold::percentage)); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + } +} + +TYPED_TEST(HipcubBlockAdjacentDifferenceSubtract, SubtractLeftPartialTile) +{ + using type = typename TestFixture::params_subtract::type; + using binary_function = typename TestFixture::params_subtract::binary_function; + + using output_type = typename TestFixture::params_subtract::output; + + using stored_type = std::conditional_t::value, int, output_type>; + + constexpr size_t block_size = 
TestFixture::params_subtract::block_size; + constexpr size_t items_per_thread = TestFixture::params_subtract::items_per_thread; + static constexpr int items_per_block = block_size * items_per_thread; + static constexpr int size = items_per_block * 20; + static constexpr int grid_size = size / items_per_block; + + // Given block size not supported + if(block_size > test_utils::get_max_block_size()) + { + return; + } + + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + const unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + + // Generate data + const std::vector input = test_utils::get_random_data(size, 0, 10, seed_value); + std::vector output(size); + + const std::vector tile_sizes + = test_utils::get_random_data(grid_size, 0, items_per_block, seed_value); + + // Calculate expected results on host + std::vector expected(size); + binary_function op; + + for(size_t block_index = 0; block_index < grid_size; ++block_index) + { + for(int item = 0; item < items_per_block; ++item) + { + const size_t i = block_index * items_per_block + item; + if (item < tile_sizes[block_index]) + { + if(item == 0) + { + expected[i] = static_cast(input[i]); + } + else + { + expected[i] = static_cast(op(input[i], input[i - 1])); + } + } + else + { + expected[i] = static_cast(input[i]); + } + } + } + + // Preparing Device + type* d_input; + int* d_tile_sizes; + stored_type* d_output; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK(hipMalloc(&d_tile_sizes, tile_sizes.size() * sizeof(tile_sizes[0]))); + HIP_CHECK(hipMalloc(&d_output, output.size() * sizeof(output[0]))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(input[0]), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_tile_sizes, tile_sizes.data(), + tile_sizes.size() * sizeof(tile_sizes[0]), + 
hipMemcpyHostToDevice + ) + ); + + // Running kernel + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + subtract_left_partial_tile_kernel + ), + dim3(grid_size), dim3(block_size), 0, 0, + d_input, d_tile_sizes, d_output + ); + HIP_CHECK(hipGetLastError()); + + // Reading results + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + output.size() * sizeof(output[0]), + hipMemcpyDeviceToHost + ) + ); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( + output, expected, test_utils::precision_threshold::percentage)); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_tile_sizes)); + HIP_CHECK(hipFree(d_output)); + } +} + +TYPED_TEST(HipcubBlockAdjacentDifferenceSubtract, SubtractRight) +{ + using type = typename TestFixture::params_subtract::type; + using binary_function = typename TestFixture::params_subtract::binary_function; + + using output_type = typename TestFixture::params_subtract::output; + + using stored_type = std::conditional_t::value, int, output_type>; + + constexpr size_t block_size = TestFixture::params_subtract::block_size; + constexpr size_t items_per_thread = TestFixture::params_subtract::items_per_thread; + static constexpr int items_per_block = block_size * items_per_thread; + static constexpr int size = items_per_block * 20; + static constexpr int grid_size = size / items_per_block; + + // Given block size not supported + if(block_size > test_utils::get_max_block_size()) + { + return; + } + + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + const unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + + // Generate data + const std::vector input = test_utils::get_random_data(size, 0, 10, seed_value); + std::vector output(size); + + // Calculate expected results on host + std::vector expected(size); + binary_function op; + + for(size_t block_index = 0; block_index < grid_size; ++block_index) + { + for(int item = 0; item < items_per_block; ++item) + { + const size_t i = block_index * items_per_block + item; + if(item == items_per_block - 1) + { + expected[i] + = static_cast(block_index % 2 == 0 ? op(input[i], input[i + 1]) : input[i]); + } + else + { + expected[i] = static_cast(op(input[i], input[i + 1])); + } + } + } + + // Preparing Device + type* d_input; + stored_type* d_output; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK(hipMalloc(&d_output, output.size() * sizeof(output[0]))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(input[0]), + hipMemcpyHostToDevice + ) + ); + + // Running kernel + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + subtract_right_kernel + ), + dim3(grid_size), dim3(block_size), 0, 0, + d_input, d_output + ); + HIP_CHECK(hipGetLastError()); + + // Reading results + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + output.size() * sizeof(output[0]), + hipMemcpyDeviceToHost + ) + ); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( + output, expected, test_utils::precision_threshold::percentage)); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + } +} + +TYPED_TEST(HipcubBlockAdjacentDifferenceSubtract, SubtractRightPartialTile) +{ + using type = typename TestFixture::params_subtract::type; + using binary_function = typename TestFixture::params_subtract::binary_function; + + using output_type = typename TestFixture::params_subtract::output; + + using stored_type = std::conditional_t::value, int, output_type>; + + constexpr size_t block_size = 
TestFixture::params_subtract::block_size; + constexpr size_t items_per_thread = TestFixture::params_subtract::items_per_thread; + static constexpr int items_per_block = block_size * items_per_thread; + static constexpr int size = items_per_block * 20; + static constexpr int grid_size = size / items_per_block; + + // Given block size not supported + if(block_size > test_utils::get_max_block_size()) + { + return; + } + + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + const unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + + // Generate data + const std::vector input = test_utils::get_random_data(size, 0, 10, seed_value); + std::vector output(size); + + const std::vector tile_sizes + = test_utils::get_random_data(grid_size, 0, items_per_block, seed_value); + + // Calculate expected results on host + std::vector expected(size); + binary_function op; + + for(size_t block_index = 0; block_index < grid_size; ++block_index) + { + for(int item = 0; item < items_per_block; ++item) + { + const size_t i = block_index * items_per_block + item; + if (item < tile_sizes[block_index]) + { + if(item == tile_sizes[block_index] - 1 || item == items_per_block - 1) + { + expected[i] = static_cast(input[i]); + } + else + { + expected[i] = static_cast(op(input[i], input[i + 1])); + } + } + else + { + expected[i] = static_cast(input[i]); + } + } + } + + // Preparing Device + type* d_input; + int* d_tile_sizes; + stored_type* d_output; + HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); + HIP_CHECK(hipMalloc(&d_tile_sizes, tile_sizes.size() * sizeof(tile_sizes[0]))); + HIP_CHECK(hipMalloc(&d_output, output.size() * sizeof(output[0]))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(input[0]), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_tile_sizes, 
tile_sizes.data(), + tile_sizes.size() * sizeof(tile_sizes[0]), + hipMemcpyHostToDevice + ) + ); + + // Running kernel + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + subtract_right_partial_tile_kernel + ), + dim3(grid_size), dim3(block_size), 0, 0, + d_input, d_tile_sizes, d_output + ); + HIP_CHECK(hipGetLastError()); + + // Reading results + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + output.size() * sizeof(output[0]), + hipMemcpyDeviceToHost + ) + ); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( + output, expected, test_utils::precision_threshold::percentage)); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_tile_sizes)); + HIP_CHECK(hipFree(d_output)); + } +} \ No newline at end of file diff --git a/test/hipcub/test_hipcub_block_radix_rank.cpp b/test/hipcub/test_hipcub_block_radix_rank.cpp index 01949d56..525be359 100644 --- a/test/hipcub/test_hipcub_block_radix_rank.cpp +++ b/test/hipcub/test_hipcub_block_radix_rank.cpp @@ -72,7 +72,7 @@ class BlockRadixSort #ifdef __HIP_PLATFORM_AMD__ KEYS_ONLY = rocprim::Equals::VALUE, #else - KEYS_ONLY = cub::Equals::VALUE, + KEYS_ONLY = std::is_same::value, #endif }; diff --git a/test/hipcub/test_hipcub_device_adjacent_difference.cpp b/test/hipcub/test_hipcub_device_adjacent_difference.cpp new file mode 100644 index 00000000..9c55b2a8 --- /dev/null +++ b/test/hipcub/test_hipcub_device_adjacent_difference.cpp @@ -0,0 +1,233 @@ +// MIT License +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "common_test_header.hpp" + +// hipcub API +#include "hipcub/device/device_adjacent_difference.hpp" + +template +hipError_t dispatch_adjacent_difference(std::true_type /*left*/, + std::true_type /*copy*/, + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractLeftCopy( + d_temp_storage, temp_storage_bytes, d_input, d_output, + std::forward(args)... + ); +} + +template +hipError_t dispatch_adjacent_difference(std::true_type /*left*/, + std::false_type /*copy*/, + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT /*d_output*/, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractLeft( + d_temp_storage, temp_storage_bytes, d_input, + std::forward(args)... 
+ ); +} + +template +hipError_t dispatch_adjacent_difference(std::false_type /*left*/, + std::true_type /*copy*/, + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractRightCopy( + d_temp_storage, temp_storage_bytes, d_input, d_output, + std::forward(args)... + ); +} + +template +hipError_t dispatch_adjacent_difference(std::false_type /*left*/, + std::false_type /*copy*/, + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT /*d_output*/, + Args&&... args) +{ + return ::hipcub::DeviceAdjacentDifference::SubtractRight( + d_temp_storage, temp_storage_bytes, d_input, + std::forward(args)... + ); +} + +template +auto get_expected_result(const std::vector& input, + const BinaryFunction op, + std::true_type /*left*/) +{ + std::vector result(input.size()); + std::adjacent_difference(input.cbegin(), input.cend(), result.begin(), op); + return result; +} + +template +auto get_expected_result(const std::vector& input, + const BinaryFunction op, + std::false_type /*left*/) +{ + std::vector result(input.size()); + // "right" adjacent difference is just adjacent difference backwards + std::adjacent_difference(input.crbegin(), input.crend(), result.rbegin(), op); + return result; +} + +template< + class InputT, + class OutputT = InputT, + bool Left = true, + bool Copy = true +> +struct params +{ + using input_type = InputT; + using output_type = OutputT; + static constexpr bool left = Left; + static constexpr bool copy = Copy; +}; + +template +class HipcubDeviceAdjacentDifference : public ::testing::Test { +public: + using params = Params; +}; + +typedef ::testing::Types< + params, + params, + params, + params, + params +> Params; + +std::vector get_sizes() +{ + std::vector sizes = { 1, 10, 53, 211, 1024, 2345, 4096, 34567, (1 << 16) - 1220, (1 << 23) - 76543 }; + const std::vector random_sizes = 
test_utils::get_random_data(10, 1, 100000, rand()); + sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); + return sizes; +} + +TYPED_TEST_SUITE(HipcubDeviceAdjacentDifference, Params); + +TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy) +{ + using input_type = typename TestFixture::params::input_type; + static constexpr std::integral_constant left_constant{}; + static constexpr std::integral_constant copy_constant{}; + using output_type = std::conditional_t; + static constexpr hipStream_t stream = 0; + static constexpr bool debug_synchronous = false; + static constexpr ::hipcub::Difference op; + + const auto sizes = get_sizes(); + for (size_t size : sizes) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + const unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + + const auto input = test_utils::get_random_data( + size, + static_cast(-50), + static_cast(50), + seed_value + ); + + input_type * d_input{}; + output_type * d_output{}; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, size * sizeof(d_input[0]))); + if (copy_constant) + { + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, size * sizeof(d_output[0]))); + } + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + size * sizeof(input_type), + hipMemcpyHostToDevice + ) + ); + + const auto expected = get_expected_result(input, op, left_constant); + + size_t temporary_storage_bytes = 0; + HIP_CHECK( + dispatch_adjacent_difference( + left_constant, copy_constant, + nullptr, temporary_storage_bytes, + d_input, d_output, size, op, stream, debug_synchronous + ) + ); + +#ifdef __HIP_PLATFORM_AMD__ + ASSERT_GT(temporary_storage_bytes, 0U); +#endif + + void * d_temporary_storage; + 
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + HIP_CHECK( + dispatch_adjacent_difference( + left_constant, copy_constant, + d_temporary_storage, temporary_storage_bytes, + d_input, d_output, size, op, stream, debug_synchronous + ) + ); + + std::vector output(size); + HIP_CHECK( + hipMemcpy( + output.data(), copy_constant ? d_output : d_input, + size * sizeof(output[0]), + hipMemcpyDeviceToHost + ) + ); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_input)); + if (copy_constant) + { + HIP_CHECK(hipFree(d_output)); + } + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(output, expected)); + } + } +} diff --git a/test/hipcub/test_hipcub_device_merge_sort.cpp b/test/hipcub/test_hipcub_device_merge_sort.cpp index a675f957..41580eda 100644 --- a/test/hipcub/test_hipcub_device_merge_sort.cpp +++ b/test/hipcub/test_hipcub_device_merge_sort.cpp @@ -99,8 +99,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeys) // Generate data std::vector keys_input; keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition); key_type * d_keys_input; HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); @@ -173,8 +173,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeysCopy) // Generate data std::vector keys_input; keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition); key_type * d_keys_input; key_type * d_keys_output; @@ -252,8 +252,8 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeys) // Generate data std::vector keys_input; keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + 
test_utils::numeric_limits::max(), seed_value + seed_value_addition); key_type * d_keys_input; HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); @@ -332,8 +332,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairs) // Generate data std::vector keys_input; keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition); std::vector values_input(size); std::iota(values_input.begin(), values_input.end(), 0); @@ -456,8 +456,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairsCopy) // Generate data std::vector keys_input; keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition); std::vector values_input(size); std::iota(values_input.begin(), values_input.end(), 0); @@ -591,8 +591,8 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortPairs) // Generate data std::vector keys_input; keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition); std::vector values_input(size); std::iota(values_input.begin(), values_input.end(), 0); diff --git a/test/hipcub/test_hipcub_device_radix_sort.cpp b/test/hipcub/test_hipcub_device_radix_sort.cpp index 5608d32e..eee7c62c 100644 --- a/test/hipcub/test_hipcub_device_radix_sort.cpp +++ b/test/hipcub/test_hipcub_device_radix_sort.cpp @@ -33,7 +33,7 @@ template< bool Descending = false, unsigned int StartBit = 0, unsigned int EndBit = sizeof(Key) * 8, - bool CheckHugeSizes = false + bool CheckLargeSizes = false > struct params { @@ -42,7 +42,7 @@ struct params static constexpr bool descending = Descending; static constexpr unsigned int start_bit = StartBit; 
static constexpr unsigned int end_bit = EndBit; - static constexpr bool check_huge_sizes = CheckHugeSizes; + static constexpr bool check_large_sizes = CheckLargeSizes; }; template @@ -79,16 +79,16 @@ typedef ::testing::Types< params, params, - // huge sizes to check correctness of more than 1 block per batch + // large sizes to check correctness of more than 1 block per batch params > Params; TYPED_TEST_SUITE(HipcubDeviceRadixSort, Params); -std::vector get_sizes() +std::vector get_sizes() { - std::vector sizes = { 1, 10, 53, 211, 1024, 2345, 4096, 34567, (1 << 16) - 1220, (1 << 23) - 76543 }; - const std::vector random_sizes = test_utils::get_random_data(10, 1, 100000, rand()); + std::vector sizes = { 1, 10, 53, 211, 1024, 2345, 4096, 34567, (1 << 16) - 1220, (1 << 23) - 76543 }; + const std::vector random_sizes = test_utils::get_random_data(10, 1, 100000, rand()); sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); return sizes; } @@ -99,16 +99,16 @@ TYPED_TEST(HipcubDeviceRadixSort, SortKeys) constexpr bool descending = TestFixture::params::descending; constexpr unsigned int start_bit = TestFixture::params::start_bit; constexpr unsigned int end_bit = TestFixture::params::end_bit; - constexpr bool check_huge_sizes = TestFixture::params::check_huge_sizes; + constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; const bool debug_synchronous = false; - const std::vector sizes = get_sizes(); - for(size_t size : sizes) + const std::vector sizes = get_sizes(); + for(unsigned int size : sizes) { - if(size > (1 << 20) && !check_huge_sizes) continue; + if(size > (1 << 20) && !check_large_sizes) continue; for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -120,8 +120,8 @@ TYPED_TEST(HipcubDeviceRadixSort, SortKeys) std::vector keys_input; keys_input = test_utils::get_random_data( size, - std::numeric_limits::min(), - std::numeric_limits::max(), + 
test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition ); test_utils::add_special_values(keys_input, seed_value); @@ -205,16 +205,16 @@ TYPED_TEST(HipcubDeviceRadixSort, SortPairs) constexpr bool descending = TestFixture::params::descending; constexpr unsigned int start_bit = TestFixture::params::start_bit; constexpr unsigned int end_bit = TestFixture::params::end_bit; - constexpr bool check_huge_sizes = TestFixture::params::check_huge_sizes; + constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; const bool debug_synchronous = false; - const std::vector sizes = get_sizes(); - for(size_t size : sizes) + const std::vector sizes = get_sizes(); + for(unsigned int size : sizes) { - if(size > (1 << 20) && !check_huge_sizes) continue; + if(size > (1 << 20) && !check_large_sizes) continue; for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -226,8 +226,8 @@ TYPED_TEST(HipcubDeviceRadixSort, SortPairs) std::vector keys_input; keys_input = test_utils::get_random_data( size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition ); test_utils::add_special_values(keys_input, seed_value); @@ -353,16 +353,16 @@ TYPED_TEST(HipcubDeviceRadixSort, SortKeysDoubleBuffer) constexpr bool descending = TestFixture::params::descending; constexpr unsigned int start_bit = TestFixture::params::start_bit; constexpr unsigned int end_bit = TestFixture::params::end_bit; - constexpr bool check_huge_sizes = TestFixture::params::check_huge_sizes; + constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; const bool debug_synchronous = false; - const std::vector sizes = get_sizes(); - for(size_t size : sizes) + const std::vector sizes = get_sizes(); + for(unsigned int size : sizes) { - if(size > (1 << 20) && 
!check_huge_sizes) continue; + if(size > (1 << 20) && !check_large_sizes) continue; for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -374,8 +374,8 @@ TYPED_TEST(HipcubDeviceRadixSort, SortKeysDoubleBuffer) std::vector keys_input; keys_input = test_utils::get_random_data( size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition ); test_utils::add_special_values(keys_input, seed_value); @@ -460,16 +460,16 @@ TYPED_TEST(HipcubDeviceRadixSort, SortPairsDoubleBuffer) constexpr bool descending = TestFixture::params::descending; constexpr unsigned int start_bit = TestFixture::params::start_bit; constexpr unsigned int end_bit = TestFixture::params::end_bit; - constexpr bool check_huge_sizes = TestFixture::params::check_huge_sizes; + constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; const bool debug_synchronous = false; - const std::vector sizes = get_sizes(); - for(size_t size : sizes) + const std::vector sizes = get_sizes(); + for(unsigned int size : sizes) { - if(size > (1 << 20) && !check_huge_sizes) continue; + if(size > (1 << 20) && !check_large_sizes) continue; for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -481,8 +481,8 @@ TYPED_TEST(HipcubDeviceRadixSort, SortPairsDoubleBuffer) std::vector keys_input; keys_input = test_utils::get_random_data( size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value + seed_value_addition ); test_utils::add_special_values(keys_input, seed_value); @@ -603,4 +603,74 @@ TYPED_TEST(HipcubDeviceRadixSort, SortPairsDoubleBuffer) ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); } } +} + +TEST(HipcubDeviceRadixSort, SortKeysOver4G) +{ + using key_type = uint8_t; + 
constexpr unsigned int start_bit = 0; + constexpr unsigned int end_bit = 8ull * sizeof(key_type); + constexpr hipStream_t stream = 0; + constexpr bool debug_synchronous = false; + constexpr size_t size = (1ull << 32) + 32; + constexpr size_t number_of_possible_keys = 1ull << (8ull * sizeof(key_type)); + assert(std::is_unsigned::value); + std::vector histogram(number_of_possible_keys, 0); + const int seed_value = rand(); + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + + std::vector keys_input = test_utils::get_random_data( + size, + std::numeric_limits::min(), + std::numeric_limits::max(), + seed_value); + + //generate histogram of the randomly generated values + std::for_each(keys_input.begin(), keys_input.end(), [&](const key_type &a){ + histogram[a]++; + }); + + key_type * d_keys_input_output{}; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input_output, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + + size_t temporary_storage_bytes; + HIP_CHECK( + hipcub::DeviceRadixSort::SortKeys( + nullptr, temporary_storage_bytes, + d_keys_input_output, d_keys_input_output, size, + start_bit, end_bit, + stream, debug_synchronous + ) + ); + + ASSERT_GT(temporary_storage_bytes, 0); + void * d_temporary_storage; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + HIP_CHECK( + hipcub::DeviceRadixSort::SortKeys( + d_temporary_storage, temporary_storage_bytes, + d_keys_input_output, d_keys_input_output, size, + start_bit, end_bit, + stream, debug_synchronous + ) + ); + + std::vector output(keys_input.size()); + HIP_CHECK(hipMemcpy(output.data(), d_keys_input_output, size * sizeof(key_type), hipMemcpyDeviceToHost)); + + size_t counter = 0; + for(size_t i = 0; i <= std::numeric_limits::max(); ++i) + { + for(size_t j = 0; j < histogram[i]; ++j) + { + ASSERT_EQ(static_cast(output[counter]), i); + ++counter; + } + } + 
ASSERT_EQ(counter, size); + + HIP_CHECK(hipFree(d_keys_input_output)); + HIP_CHECK(hipFree(d_temporary_storage)); } \ No newline at end of file diff --git a/test/hipcub/test_hipcub_device_select.cpp b/test/hipcub/test_hipcub_device_select.cpp index 8f1eef80..0ec64ed3 100644 --- a/test/hipcub/test_hipcub_device_select.cpp +++ b/test/hipcub/test_hipcub_device_select.cpp @@ -486,3 +486,217 @@ TYPED_TEST(HipcubDeviceSelectTests, Unique) } } } + +template< + typename KeyType, + typename ValueType, + typename OutputKeyType = KeyType, + typename OutputValueType = ValueType +> +struct DeviceUniqueByKeyParams +{ + using key_type = KeyType; + using value_type = ValueType; + using output_key_type = OutputKeyType; + using output_value_type = OutputValueType; +}; + +template +class HipcubDeviceUniqueByKeyTests : public ::testing::Test +{ +public: + using key_type = typename Params::key_type; + using value_type = typename Params::value_type; + using output_key_type = typename Params::output_key_type; + using output_value_type = typename Params::output_value_type; + const bool debug_synchronous = false; +}; + +typedef ::testing::Types< + DeviceUniqueByKeyParams, + DeviceUniqueByKeyParams, + DeviceUniqueByKeyParams, + DeviceUniqueByKeyParams, test_utils::custom_test_type> +> HipcubDeviceUniqueByKeyTestsParams; + +TYPED_TEST_SUITE(HipcubDeviceUniqueByKeyTests, HipcubDeviceUniqueByKeyTestsParams); + +TYPED_TEST(HipcubDeviceUniqueByKeyTests, UniqueByKey) +{ + using key_type = typename TestFixture::key_type; + using value_type = typename TestFixture::value_type; + using output_key_type = typename TestFixture::output_key_type; + using output_value_type = typename TestFixture::output_value_type; + + const bool debug_synchronous = TestFixture::debug_synchronous; + + hipStream_t stream = 0; // default stream + + const auto sizes = get_sizes(); + const auto probabilities = get_discontinuity_probabilities(); + + for (auto size : sizes) + { + SCOPED_TRACE(testing::Message() << "with size = " 
<< size); + + for (auto p : probabilities) + { + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + SCOPED_TRACE(testing::Message() << "with p = " << p); + + // Generate data + std::vector input_keys(size); + { + std::vector input01 = test_utils::get_random_data01(size, p, seed_value); + test_utils::host_inclusive_scan( + input01.begin(), input01.end(), input_keys.begin(), hipcub::Sum() + ); + } + + const auto input_values = test_utils::get_random_data(size, -1000, 1000, seed_value); + + // Allocate and copy to device + key_type* d_keys_input; + value_type* d_values_input; + output_key_type* d_keys_output; + output_value_type* d_values_output; + + unsigned int* d_selected_count_output; + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, input_keys.size() * sizeof(input_keys[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_input, input_values.size() * sizeof(input_values[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, input_keys.size() * sizeof(input_keys[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_output, input_values.size() * sizeof(input_values[0]))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_selected_count_output, sizeof(unsigned int))); + + HIP_CHECK( + hipMemcpy( + d_keys_input, input_keys.data(), + input_keys.size() * sizeof(input_keys[0]), + hipMemcpyHostToDevice + ) + ); + + HIP_CHECK( + hipMemcpy( + d_values_input, input_values.data(), + input_values.size() * sizeof(input_values[0]), + hipMemcpyHostToDevice + ) + ); + + HIP_CHECK(hipDeviceSynchronize()); + + // Calculate expected result on host + std::vector expected_keys; + std::vector expected_values; + expected_keys.reserve(input_keys.size()); + expected_values.reserve(input_values.size()); + 
expected_keys.push_back(input_keys[0]); + expected_values.push_back(input_values[0]); + + for (size_t i = 1; i < input_keys.size(); i++) + { + if (!(input_keys[i-1] == input_keys[i])) + { + expected_keys.push_back(input_keys[i]); + expected_values.push_back(input_values[i]); + } + } + + // temp storage + size_t temp_storage_size_bytes; + // Get the size of d_temp_storage + HIP_CHECK( + hipcub::DeviceSelect::UniqueByKey( + nullptr, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + d_keys_output, + d_values_output, + d_selected_count_output, + input_keys.size(), + stream, + debug_synchronous + ) + ); + HIP_CHECK(hipDeviceSynchronize()); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + void * d_temp_storage = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // run + HIP_CHECK( + hipcub::DeviceSelect::UniqueByKey( + d_temp_storage, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + d_keys_output, + d_values_output, + d_selected_count_output, + input_keys.size(), + stream, + debug_synchronous + ) + ); + HIP_CHECK(hipDeviceSynchronize()); + + // Check if number of selected value is as expected + unsigned int selected_count_output = 0; + HIP_CHECK( + hipMemcpy( + &selected_count_output, d_selected_count_output, + sizeof(unsigned int), + hipMemcpyDeviceToHost + ) + ); + HIP_CHECK(hipDeviceSynchronize()); + + ASSERT_EQ(selected_count_output, expected_keys.size()); + + // Check if outputs are as expected + std::vector output_keys(input_keys.size()); + + HIP_CHECK( + hipMemcpy( + output_keys.data(), d_keys_output, + output_keys.size() * sizeof(output_keys[0]), + hipMemcpyDeviceToHost + ) + ); + + std::vector output_values(input_values.size()); + + HIP_CHECK( + hipMemcpy( + output_values.data(), d_values_output, + output_values.size() * sizeof(output_values[0]), + hipMemcpyDeviceToHost + ) + 
); + HIP_CHECK(hipDeviceSynchronize()); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_keys, expected_keys, expected_keys.size())); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_values, expected_values, expected_values.size())); + + hipFree(d_keys_input); + hipFree(d_values_input); + hipFree(d_keys_output); + hipFree(d_values_output); + hipFree(d_selected_count_output); + hipFree(d_temp_storage); + } + } + } +} \ No newline at end of file diff --git a/test/hipcub/test_utils_data_generation.hpp b/test/hipcub/test_utils_data_generation.hpp index 03258b85..868f673c 100644 --- a/test/hipcub/test_utils_data_generation.hpp +++ b/test/hipcub/test_utils_data_generation.hpp @@ -104,7 +104,7 @@ struct special_values { return std::vector(); }else { std::vector r = {test_utils::numeric_limits::quiet_NaN(), - //sign_bit_flip(test_utils::numeric_limits::quiet_NaN()), // TODO: fix AMD issue with -NaN + sign_bit_flip(test_utils::numeric_limits::quiet_NaN()), //test_utils::numeric_limits::signaling_NaN(), // signaling_NaN not supported on NVIDIA yet //sign_bit_flip(test_utils::numeric_limits::signaling_NaN()), test_utils::numeric_limits::infinity(), diff --git a/test/hipcub/test_utils_sort_comparator.hpp b/test/hipcub/test_utils_sort_comparator.hpp index 29c231b9..30b9708c 100644 --- a/test/hipcub/test_utils_sort_comparator.hpp +++ b/test/hipcub/test_utils_sort_comparator.hpp @@ -97,10 +97,22 @@ struct key_comparator struct key_comparator::value || - std::is_same::value>::type> + typename std::enable_if::value>::type> +{ + bool operator()(const Key& lhs, const Key& rhs) + { + test_utils::native_half lhs_native(lhs); + test_utils::native_half rhs_native(rhs); + return key_comparator()(lhs_native, rhs_native); + } +}; + +template +struct key_comparator::value>::type> { bool operator()(const Key& lhs, const Key& rhs) {