From 76c772e6ce279421a957af8557bb20e188e5ba42 Mon Sep 17 00:00:00 2001 From: Karthikeyan <6488848+karthikeyann@users.noreply.github.com> Date: Tue, 22 Mar 2022 19:13:04 +0530 Subject: [PATCH] generate benchmark input in device (#10109) To speed up benchmark input generation, move all data generation to the device. To address https://github.com/rapidsai/cudf/issues/5773#issuecomment-988153942, this PR moves the random input generation to the device. The rest of the original work in this PR was split into multiple PRs and merged: #10277 #10278 #10279 #10280 #10281 #10300 With all of these changes, a single iteration of all benchmarks runs in <1000 seconds (down from 3067s to 964s). Running more iterations would see an even higher benefit, because the benchmark is restarted several times during a run, and each restart calls the benchmark input generation code again. closes https://github.com/rapidsai/cudf/issues/9857 Authors: - Karthikeyan (https://github.com/karthikeyann) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Vukasin Milovanovic (https://github.com/vuule) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/10109 --- cpp/benchmarks/CMakeLists.txt | 6 +- cpp/benchmarks/ast/transform.cpp | 3 +- cpp/benchmarks/column/concatenate.cpp | 126 ++-- .../{generate_input.cpp => generate_input.cu} | 571 ++++++++++-------- cpp/benchmarks/common/generate_input.hpp | 41 +- cpp/benchmarks/common/generate_nullmask.cu | 59 -- .../common/random_distribution_factory.cuh | 179 ++++++ .../common/random_distribution_factory.hpp | 127 ---- cpp/benchmarks/copying/contiguous_split.cu | 84 ++- cpp/benchmarks/copying/gather.cu | 56 +- cpp/benchmarks/copying/shift.cu | 29 +- cpp/benchmarks/filling/repeat.cpp | 62 +- cpp/benchmarks/groupby/group_no_requests.cu | 43 +- cpp/benchmarks/groupby/group_nth.cu | 28 +- cpp/benchmarks/groupby/group_scan.cu | 48 +- cpp/benchmarks/groupby/group_shift.cu | 23 +- cpp/benchmarks/groupby/group_struct.cu | 5 - 
cpp/benchmarks/groupby/group_sum.cu | 51 +- cpp/benchmarks/hashing/partition.cpp | 26 +- cpp/benchmarks/merge/merge.cpp | 4 +- cpp/benchmarks/quantiles/quantiles.cpp | 41 +- cpp/benchmarks/reduction/anyall.cpp | 22 +- cpp/benchmarks/reduction/dictionary.cpp | 28 +- cpp/benchmarks/reduction/minmax.cpp | 25 +- cpp/benchmarks/reduction/reduce.cpp | 22 +- cpp/benchmarks/search/search.cpp | 37 +- cpp/benchmarks/sort/rank.cpp | 23 +- cpp/benchmarks/sort/sort.cpp | 33 +- .../stream_compaction/apply_boolean_mask.cpp | 54 +- cpp/benchmarks/stream_compaction/distinct.cpp | 27 +- cpp/benchmarks/stream_compaction/unique.cpp | 27 +- cpp/benchmarks/string/extract.cpp | 19 +- cpp/benchmarks/string/json.cu | 10 +- .../transpose/{transpose.cu => transpose.cpp} | 31 +- 34 files changed, 918 insertions(+), 1052 deletions(-) rename cpp/benchmarks/common/{generate_input.cpp => generate_input.cu} (51%) delete mode 100644 cpp/benchmarks/common/generate_nullmask.cu create mode 100644 cpp/benchmarks/common/random_distribution_factory.cuh delete mode 100644 cpp/benchmarks/common/random_distribution_factory.hpp rename cpp/benchmarks/transpose/{transpose.cu => transpose.cpp} (67%) diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 47327a505f1..9e8a632a3ae 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -14,7 +14,7 @@ find_package(Threads REQUIRED) -add_library(cudf_datagen STATIC common/generate_input.cpp common/generate_nullmask.cu) +add_library(cudf_datagen STATIC common/generate_input.cu) target_compile_features(cudf_datagen PUBLIC cxx_std_17 cuda_std_17) target_compile_options( @@ -136,7 +136,7 @@ ConfigureBench(COPY_IF_ELSE_BENCH copying/copy_if_else.cpp) # ################################################################################################## # * transpose benchmark --------------------------------------------------------------------------- -ConfigureBench(TRANSPOSE_BENCH transpose/transpose.cu) 
+ConfigureBench(TRANSPOSE_BENCH transpose/transpose.cpp) # ################################################################################################## # * apply_boolean_mask benchmark ------------------------------------------------------------------ @@ -145,7 +145,7 @@ ConfigureBench(APPLY_BOOLEAN_MASK_BENCH stream_compaction/apply_boolean_mask.cpp # ################################################################################################## # * stream_compaction benchmark ------------------------------------------------------------------- ConfigureNVBench( - STREAM_COMPACTION_BENCH stream_compaction/distinct.cpp stream_compaction/unique.cpp + STREAM_COMPACTION_NVBENCH stream_compaction/distinct.cpp stream_compaction/unique.cpp ) # ################################################################################################## diff --git a/cpp/benchmarks/ast/transform.cpp b/cpp/benchmarks/ast/transform.cpp index de0429f74ad..20e64e0a90d 100644 --- a/cpp/benchmarks/ast/transform.cpp +++ b/cpp/benchmarks/ast/transform.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include enum class TreeType { @@ -48,7 +49,7 @@ static void BM_ast_transform(benchmark::State& state) auto const source_table = create_sequence_table(cycle_dtypes({cudf::type_to_id()}, n_cols), row_count{table_size}, - Nullable ? 0.5 : -1.0); + Nullable ? std::optional{0.5} : std::nullopt); auto table = source_table->view(); // Create column references diff --git a/cpp/benchmarks/column/concatenate.cpp b/cpp/benchmarks/column/concatenate.cpp index abca4b4e0f5..89f5fcb27a6 100644 --- a/cpp/benchmarks/column/concatenate.cpp +++ b/cpp/benchmarks/column/concatenate.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include - -#include - +#include #include #include #include +#include + +#include +#include + #include #include @@ -33,31 +34,14 @@ class Concatenate : public cudf::benchmark { template static void BM_concatenate(benchmark::State& state) { - using column_wrapper = cudf::test::fixed_width_column_wrapper; - - auto const num_rows = state.range(0); - auto const num_cols = state.range(1); - - // Create owning columns - std::vector columns; - columns.reserve(num_cols); - std::generate_n(std::back_inserter(columns), num_cols, [num_rows]() { - auto iter = thrust::make_counting_iterator(0); - if (Nullable) { - auto valid_iter = thrust::make_transform_iterator(iter, [](auto i) { return i % 3 == 0; }); - return column_wrapper(iter, iter + num_rows, valid_iter); - } else { - return column_wrapper(iter, iter + num_rows); - } - }); + cudf::size_type const num_rows = state.range(0); + cudf::size_type const num_cols = state.range(1); - // Generate column views - std::vector column_views; - column_views.reserve(columns.size()); - std::transform( - columns.begin(), columns.end(), std::back_inserter(column_views), [](auto const& col) { - return static_cast(col); - }); + auto input = create_sequence_table(cycle_dtypes({cudf::type_to_id()}, num_cols), + row_count{num_rows}, + Nullable ? 
std::optional{2.0 / 3.0} : std::nullopt); + auto input_columns = input->view(); + std::vector column_views(input_columns.begin(), input_columns.end()); CHECK_CUDA(0); @@ -69,11 +53,13 @@ static void BM_concatenate(benchmark::State& state) state.SetBytesProcessed(state.iterations() * num_cols * num_rows * sizeof(T)); } -#define CONCAT_BENCHMARK_DEFINE(type, nullable) \ - TEMPLATED_BENCHMARK_F(Concatenate, BM_concatenate, type, nullable) \ - ->RangeMultiplier(8) \ - ->Ranges({{1 << 6, 1 << 18}, {2, 1024}}) \ - ->Unit(benchmark::kMillisecond) \ +#define CONCAT_BENCHMARK_DEFINE(type, nullable) \ + BENCHMARK_DEFINE_F(Concatenate, BM_concatenate##_##nullable_##nullable) \ + (::benchmark::State & st) { BM_concatenate(st); } \ + BENCHMARK_REGISTER_F(Concatenate, BM_concatenate##_##nullable_##nullable) \ + ->RangeMultiplier(8) \ + ->Ranges({{1 << 6, 1 << 18}, {2, 1024}}) \ + ->Unit(benchmark::kMillisecond) \ ->UseManualTime(); CONCAT_BENCHMARK_DEFINE(int64_t, false) @@ -82,42 +68,22 @@ CONCAT_BENCHMARK_DEFINE(int64_t, true) template static void BM_concatenate_tables(benchmark::State& state) { - using column_wrapper = cudf::test::fixed_width_column_wrapper; - - auto const num_rows = state.range(0); - auto const num_cols = state.range(1); - auto const num_tables = state.range(2); - - // Create owning columns - std::vector columns; - columns.reserve(num_cols); - std::generate_n(std::back_inserter(columns), num_cols * num_tables, [num_rows]() { - auto iter = thrust::make_counting_iterator(0); - if (Nullable) { - auto valid_iter = thrust::make_transform_iterator(iter, [](auto i) { return i % 3 == 0; }); - return column_wrapper(iter, iter + num_rows, valid_iter); - } else { - return column_wrapper(iter, iter + num_rows); - } + cudf::size_type const num_rows = state.range(0); + cudf::size_type const num_cols = state.range(1); + cudf::size_type const num_tables = state.range(2); + + std::vector> tables(num_tables); + std::generate_n(tables.begin(), num_tables, [&]() { + return 
create_sequence_table(cycle_dtypes({cudf::type_to_id()}, num_cols), + row_count{num_rows}, + Nullable ? std::optional{2.0 / 3.0} : std::nullopt); }); - // Generate column views - std::vector> column_views(num_tables); - for (int i = 0; i < num_tables; ++i) { - column_views[i].reserve(num_cols); - auto it = columns.begin() + (i * num_cols); - std::transform(it, it + num_cols, std::back_inserter(column_views[i]), [](auto const& col) { - return static_cast(col); - }); - } - // Generate table views - std::vector table_views; - table_views.reserve(num_tables); - std::transform(column_views.begin(), - column_views.end(), - std::back_inserter(table_views), - [](auto const& col_vec) { return cudf::table_view(col_vec); }); + std::vector table_views(num_tables); + std::transform(tables.begin(), tables.end(), table_views.begin(), [](auto& table) mutable { + return table->view(); + }); CHECK_CUDA(0); @@ -129,11 +95,13 @@ static void BM_concatenate_tables(benchmark::State& state) state.SetBytesProcessed(state.iterations() * num_cols * num_rows * num_tables * sizeof(T)); } -#define CONCAT_TABLES_BENCHMARK_DEFINE(type, nullable) \ - TEMPLATED_BENCHMARK_F(Concatenate, BM_concatenate_tables, type, nullable) \ - ->RangeMultiplier(8) \ - ->Ranges({{1 << 8, 1 << 12}, {2, 32}, {2, 128}}) \ - ->Unit(benchmark::kMillisecond) \ +#define CONCAT_TABLES_BENCHMARK_DEFINE(type, nullable) \ + BENCHMARK_DEFINE_F(Concatenate, BM_concatenate_tables##_##nullable_##nullable) \ + (::benchmark::State & st) { BM_concatenate_tables(st); } \ + BENCHMARK_REGISTER_F(Concatenate, BM_concatenate_tables##_##nullable_##nullable) \ + ->RangeMultiplier(8) \ + ->Ranges({{1 << 8, 1 << 12}, {2, 32}, {2, 128}}) \ + ->Unit(benchmark::kMillisecond) \ ->UseManualTime(); CONCAT_TABLES_BENCHMARK_DEFINE(int64_t, false) @@ -187,11 +155,13 @@ static void BM_concatenate_strings(benchmark::State& state) (sizeof(int32_t) + num_chars)); // offset + chars } -#define CONCAT_STRINGS_BENCHMARK_DEFINE(nullable) \ - 
TEMPLATED_BENCHMARK_F(ConcatenateStrings, BM_concatenate_strings, nullable) \ - ->RangeMultiplier(8) \ - ->Ranges({{1 << 8, 1 << 14}, {8, 128}, {2, 256}}) \ - ->Unit(benchmark::kMillisecond) \ +#define CONCAT_STRINGS_BENCHMARK_DEFINE(nullable) \ + BENCHMARK_DEFINE_F(Concatenate, BM_concatenate_strings##_##nullable_##nullable) \ + (::benchmark::State & st) { BM_concatenate_strings(st); } \ + BENCHMARK_REGISTER_F(Concatenate, BM_concatenate_strings##_##nullable_##nullable) \ + ->RangeMultiplier(8) \ + ->Ranges({{1 << 8, 1 << 14}, {8, 128}, {2, 256}}) \ + ->Unit(benchmark::kMillisecond) \ ->UseManualTime(); CONCAT_STRINGS_BENCHMARK_DEFINE(false) diff --git a/cpp/benchmarks/common/generate_input.cpp b/cpp/benchmarks/common/generate_input.cu similarity index 51% rename from cpp/benchmarks/common/generate_input.cpp rename to cpp/benchmarks/common/generate_input.cu index 6330beda54c..460483e37a4 100644 --- a/cpp/benchmarks/common/generate_input.cpp +++ b/cpp/benchmarks/common/generate_input.cu @@ -15,33 +15,50 @@ */ #include "generate_input.hpp" -#include "random_distribution_factory.hpp" +#include "random_distribution_factory.cuh" #include -#include +#include +#include +#include #include +#include #include #include -#include - -#include -#include +#include +#include #include #include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include #include #include #include -#include +#include #include /** * @brief Mersenne Twister pseudo-random engine. */ -auto deterministic_engine(unsigned seed) { return std::mt19937{seed}; } +auto deterministic_engine(unsigned seed) { return thrust::minstd_rand{seed}; } /** * Computes the mean value for a distribution of given type and value bounds. 
@@ -109,6 +126,29 @@ size_t avg_element_size(data_profile const& profile, cudf::data_type dtype) return cudf::type_dispatcher(dtype, non_fixed_width_size_fn{}, profile); } +/** + * @brief bool generator with given probability [0.0 - 1.0] of returning true. + */ +struct bool_generator { + thrust::minstd_rand engine; + thrust::uniform_real_distribution dist; + double probability_true; + bool_generator(thrust::minstd_rand engine, double probability_true) + : engine(engine), dist{0, 1}, probability_true{probability_true} + { + } + bool_generator(unsigned seed, double probability_true) + : engine(seed), dist{0, 1}, probability_true{probability_true} + { + } + + __device__ bool operator()(size_t n) + { + engine.discard(n); + return dist(engine) < probability_true; + } +}; + /** * @brief Functor that computes a random column element with the given data profile. * @@ -123,8 +163,8 @@ struct random_value_fn; */ template struct random_value_fn()>> { - std::function seconds_gen; - std::function nanoseconds_gen; + distribution_fn seconds_gen; + distribution_fn nanoseconds_gen; random_value_fn(distribution_params params) { @@ -140,7 +180,11 @@ struct random_value_fn()>> { nanoseconds_gen = make_distribution(distribution_id::UNIFORM, 0l, 1000000000l); } else { // Don't need a random seconds generator for sub-second intervals - seconds_gen = [=](std::mt19937&) { return range_s.second.count(); }; + seconds_gen = [range_s](thrust::minstd_rand&, size_t size) { + rmm::device_uvector result(size, rmm::cuda_stream_default); + thrust::fill(thrust::device, result.begin(), result.end(), range_s.second.count()); + return result; + }; std::pair const range_ns = { duration_cast(typename T::duration{params.lower_bound}), @@ -151,17 +195,29 @@ struct random_value_fn()>> { } } - T operator()(std::mt19937& engine) + rmm::device_uvector operator()(thrust::minstd_rand& engine, unsigned size) { - auto const timestamp_ns = - cudf::duration_s{seconds_gen(engine)} + 
cudf::duration_ns{nanoseconds_gen(engine)}; - // Return value in the type's precision - return T(cuda::std::chrono::duration_cast(timestamp_ns)); + auto const sec = seconds_gen(engine, size); + auto const ns = nanoseconds_gen(engine, size); + rmm::device_uvector result(size, rmm::cuda_stream_default); + thrust::transform( + thrust::device, + sec.begin(), + sec.end(), + ns.begin(), + result.begin(), + [] __device__(int64_t sec_value, int64_t nanoseconds_value) { + auto const timestamp_ns = + cudf::duration_s{sec_value} + cudf::duration_ns{nanoseconds_value}; + // Return value in the type's precision + return T(cuda::std::chrono::duration_cast(timestamp_ns)); + }); + return result; } }; /** - * @brief Creates an random fixed_point value. Not implemented yet. + * @brief Creates an random fixed_point value. */ template struct random_value_fn()>> { @@ -178,15 +234,27 @@ struct random_value_fn()>> { { } - T operator()(std::mt19937& engine) + rmm::device_uvector operator()(thrust::minstd_rand& engine, unsigned size) { if (not scale.has_value()) { int const max_scale = std::numeric_limits::digits10; - auto scale_dist = make_distribution(distribution_id::NORMAL, -max_scale, max_scale); - scale = numeric::scale_type{std::max(std::min(scale_dist(engine), max_scale), -max_scale)}; + std::uniform_int_distribution scale_dist{-max_scale, max_scale}; + std::mt19937 engine_scale(engine()); + scale = numeric::scale_type{scale_dist(engine_scale)}; } + auto const ints = dist(engine, size); + rmm::device_uvector result(size, rmm::cuda_stream_default); // Clamp the generated random value to the specified range - return T{std::max(std::min(dist(engine), upper_bound), lower_bound), *scale}; + thrust::transform(thrust::device, + ints.begin(), + ints.end(), + result.begin(), + [scale = *(this->scale), + upper_bound = this->upper_bound, + lower_bound = this->lower_bound] __device__(auto int_value) { + return T{std::clamp(int_value, lower_bound, upper_bound), scale}; + }); + return result; } 
}; @@ -206,42 +274,29 @@ struct random_value_fn && cudf::is_ { } - T operator()(std::mt19937& engine) - { - // Clamp the generated random value to the specified range - return std::max(std::min(dist(engine), upper_bound), lower_bound); - } + auto operator()(thrust::minstd_rand& engine, unsigned size) { return dist(engine, size); } }; /** * @brief Creates an boolean value with given probability of returning `true`. */ template -struct random_value_fn>> { - std::bernoulli_distribution b_dist; - - random_value_fn(distribution_params const& desc) : b_dist{desc.probability_true} {} - bool operator()(std::mt19937& engine) { return b_dist(engine); } -}; - -size_t null_mask_size(cudf::size_type num_rows) -{ - constexpr size_t bitmask_bits = cudf::detail::size_in_bits(); - return (num_rows + bitmask_bits - 1) / bitmask_bits; -} -template -void set_element_at(T value, - bool valid, - std::vector& values, - std::vector& null_mask, - cudf::size_type idx) -{ - if (valid) { - values[idx] = value; - } else { - cudf::clear_bit_unsafe(null_mask.data(), idx); +struct random_value_fn>> { + // Bernoulli distribution + distribution_fn dist; + + random_value_fn(distribution_params const& desc) + : dist{[valid_prob = desc.probability_true](thrust::minstd_rand& engine, + size_t size) -> rmm::device_uvector { + rmm::device_uvector result(size, rmm::cuda_stream_default); + thrust::tabulate( + thrust::device, result.begin(), result.end(), bool_generator(engine, valid_prob)); + return result; + }} + { } -} + auto operator()(thrust::minstd_rand& engine, unsigned size) { return dist(engine, size); } +}; auto create_run_length_dist(cudf::size_type avg_run_len) { @@ -250,17 +305,52 @@ auto create_run_length_dist(cudf::size_type avg_run_len) return std::gamma_distribution{alpha, avg_run_len / alpha}; } -// identity mapping, except for bools -template -struct stored_as { - using type = T; -}; - -// Use `int8_t` for bools because that's how they're stored in columns -template -struct stored_as>> { - 
using type = int8_t; -}; +/** + * @brief Generate indices within range [0 , cardinality) repeating with average run length + * `avg_run_len` + * + * @param avg_run_len Average run length of the generated indices + * @param cardinality Number of unique values in the output vector + * @param num_rows Number of indices to generate + * @param engine Random engine + * @return Generated indices of type `cudf::size_type` + */ +rmm::device_uvector sample_indices_with_run_length(cudf::size_type avg_run_len, + cudf::size_type cardinality, + cudf::size_type num_rows, + thrust::minstd_rand& engine) +{ + auto sample_dist = random_value_fn{ + distribution_params{distribution_id::UNIFORM, 0, cardinality - 1}}; + if (avg_run_len > 1) { + auto avglen_dist = + random_value_fn{distribution_params{distribution_id::UNIFORM, 1, 2 * avg_run_len}}; + auto const approx_run_len = num_rows / avg_run_len + 1; + auto run_lens = avglen_dist(engine, approx_run_len); + thrust::inclusive_scan( + thrust::device, run_lens.begin(), run_lens.end(), run_lens.begin(), std::plus{}); + auto const samples_indices = sample_dist(engine, approx_run_len + 1); + // This is gather. + auto avg_repeated_sample_indices_iterator = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + [rb = run_lens.begin(), + re = run_lens.end(), + samples_indices = samples_indices.begin()] __device__(cudf::size_type i) { + auto sample_idx = thrust::upper_bound(thrust::seq, rb, re, i) - rb; + return samples_indices[sample_idx]; + }); + rmm::device_uvector repeated_sample_indices(num_rows, + rmm::cuda_stream_default); + thrust::copy(thrust::device, + avg_repeated_sample_indices_iterator, + avg_repeated_sample_indices_iterator + num_rows, + repeated_sample_indices.begin()); + return repeated_sample_indices; + } else { + // generate n samples. + return sample_dist(engine, num_rows); + } +} /** * @brief Creates a column with random content of type @ref T. 
@@ -274,128 +364,127 @@ struct stored_as>> { */ template std::unique_ptr create_random_column(data_profile const& profile, - std::mt19937& engine, + thrust::minstd_rand& engine, cudf::size_type num_rows) { - // Working around vector and storing bools as int8_t - using stored_Type = typename stored_as::type; - - auto valid_dist = std::bernoulli_distribution{1. - profile.get_null_frequency()}; + // Bernoulli distribution + auto valid_dist = + random_value_fn(distribution_params{1. - profile.get_null_frequency().value_or(0)}); auto value_dist = random_value_fn{profile.get_distribution_params()}; - auto const cardinality = std::min(num_rows, profile.get_cardinality()); - std::vector samples(cardinality); - std::vector samples_null_mask(null_mask_size(cardinality), ~0); - for (cudf::size_type si = 0; si < cardinality; ++si) { - set_element_at( - (stored_Type)value_dist(engine), valid_dist(engine), samples, samples_null_mask, si); - } + auto const cardinality = std::min(num_rows, profile.get_cardinality()); + rmm::device_uvector samples_null_mask = valid_dist(engine, cardinality); + rmm::device_uvector samples = value_dist(engine, cardinality); // Distribution for picking elements from the array of samples - std::uniform_int_distribution sample_dist{0, cardinality - 1}; auto const avg_run_len = profile.get_avg_run_length(); - auto run_len_dist = create_run_length_dist(avg_run_len); - std::vector data(num_rows); - std::vector null_mask(null_mask_size(num_rows), ~0); + rmm::device_uvector data(0, rmm::cuda_stream_default); + rmm::device_uvector null_mask(0, rmm::cuda_stream_default); - for (cudf::size_type row = 0; row < num_rows; ++row) { - if (cardinality == 0) { - set_element_at((stored_Type)value_dist(engine), valid_dist(engine), data, null_mask, row); - } else { - auto const sample_idx = sample_dist(engine); - set_element_at(samples[sample_idx], - cudf::bit_is_set(samples_null_mask.data(), sample_idx), - data, - null_mask, - row); - } - - if (avg_run_len > 1) { - int 
const run_len = std::min(num_rows - row, std::round(run_len_dist(engine))); - for (int offset = 1; offset < run_len; ++offset) { - set_element_at( - data[row], cudf::bit_is_set(null_mask.data(), row), data, null_mask, row + offset); - } - row += std::max(run_len - 1, 0); - } + if (cardinality == 0) { + data = value_dist(engine, num_rows); + null_mask = valid_dist(engine, num_rows); + } else { + // generate n samples and gather. + auto const sample_indices = + sample_indices_with_run_length(avg_run_len, cardinality, num_rows, engine); + data = rmm::device_uvector(num_rows, rmm::cuda_stream_default); + null_mask = rmm::device_uvector(num_rows, rmm::cuda_stream_default); + thrust::gather( + thrust::device, sample_indices.begin(), sample_indices.end(), samples.begin(), data.begin()); + thrust::gather(thrust::device, + sample_indices.begin(), + sample_indices.end(), + samples_null_mask.begin(), + null_mask.begin()); } - // cudf expects the null mask buffer to be padded up to 64 bytes. so allocate the proper size and - // copy what we have. - rmm::device_buffer result_bitmask{cudf::bitmask_allocation_size_bytes(num_rows), - rmm::cuda_stream_default}; - cudaMemcpyAsync(result_bitmask.data(), - null_mask.data(), - null_mask.size() * sizeof(cudf::bitmask_type), - cudaMemcpyHostToDevice, - rmm::cuda_stream_default); + auto [result_bitmask, null_count] = + cudf::detail::valid_if(null_mask.begin(), null_mask.end(), thrust::identity{}); return std::make_unique( cudf::data_type{cudf::type_to_id()}, num_rows, - rmm::device_buffer(data.data(), num_rows * sizeof(stored_Type), rmm::cuda_stream_default), - std::move(result_bitmask)); + data.release(), + profile.get_null_frequency().has_value() ? std::move(result_bitmask) : rmm::device_buffer{}); } -/** - * @brief Class that holds string column data in host memory. 
- */ -struct string_column_data { - std::vector chars; - std::vector offsets; - std::vector null_mask; - explicit string_column_data(cudf::size_type rows, cudf::size_type size) +struct valid_or_zero { + template + __device__ T operator()(thrust::tuple len_valid) const { - offsets.reserve(rows + 1); - offsets.push_back(0); - chars.reserve(size); - null_mask.insert(null_mask.end(), null_mask_size(rows), ~0); + return thrust::get<1>(len_valid) ? thrust::get<0>(len_valid) : T{0}; } }; -/** - * @brief Copy a string from one host-side "column" to another. - * - * Assumes that the destination null mask is initialized with all bits valid. - */ -void copy_string(cudf::size_type src_idx, - string_column_data const& src, - cudf::size_type dst_idx, - string_column_data& dst) -{ - if (!cudf::bit_is_set(src.null_mask.data(), src_idx)) - cudf::clear_bit_unsafe(dst.null_mask.data(), dst_idx); - auto const str_len = src.offsets[src_idx + 1] - src.offsets[src_idx]; - dst.chars.resize(dst.chars.size() + str_len); - if (cudf::bit_is_set(src.null_mask.data(), src_idx)) { - std::copy_n( - src.chars.begin() + src.offsets[src_idx], str_len, dst.chars.begin() + dst.offsets.back()); +struct string_generator { + char* chars; + thrust::minstd_rand engine; + thrust::uniform_int_distribution char_dist; + string_generator(char* c, thrust::minstd_rand& engine) + : chars(c), engine(engine), char_dist(32, 137) + // ~90% ASCII, ~10% UTF-8. + // ~80% not-space, ~20% space. + // range 32-127 is ASCII; 127-136 will be multi-byte UTF-8 + { } - dst.offsets.push_back(dst.chars.size()); -} + __device__ void operator()(thrust::tuple str_begin_end) + { + auto begin = thrust::get<0>(str_begin_end); + auto end = thrust::get<1>(str_begin_end); + engine.discard(begin); + for (auto i = begin; i < end; ++i) { + auto ch = char_dist(engine); + if (i == end - 1 && ch >= '\x7F') ch = ' '; // last element ASCII only. 
+ if (ch >= '\x7F') // x7F is at the top edge of ASCII + chars[i++] = '\xC4'; // these characters are assigned two bytes + chars[i] = static_cast(ch + (ch >= '\x7F')); + } + } +}; /** - * @brief Generate a random string at the end of the host-side "column". + * @brief Create a UTF-8 string column with the average length. * - * Assumes that the destination null mask is initialized with all bits valid. */ -template -void append_string(Char_gen& char_gen, bool valid, uint32_t length, string_column_data& column_data) +std::unique_ptr create_random_utf8_string_column(data_profile const& profile, + thrust::minstd_rand& engine, + cudf::size_type num_rows) { - if (!valid) { - auto const idx = column_data.offsets.size() - 1; - cudf::clear_bit_unsafe(column_data.null_mask.data(), idx); - // duplicate the offset value to indicate an empty row - column_data.offsets.push_back(column_data.offsets.back()); - return; - } - for (uint32_t idx = 0; idx < length; ++idx) { - auto const ch = char_gen(); - if (ch >= '\x7F') // x7F is at the top edge of ASCII - column_data.chars.push_back('\xC4'); // these characters are assigned two bytes - column_data.chars.push_back(static_cast(ch + (ch >= '\x7F'))); - } - column_data.offsets.push_back(column_data.chars.size()); + auto len_dist = + random_value_fn{profile.get_distribution_params().length_params}; + auto valid_dist = + random_value_fn(distribution_params{1. 
- profile.get_null_frequency().value_or(0)}); + auto lengths = len_dist(engine, num_rows + 1); + auto null_mask = valid_dist(engine, num_rows + 1); + thrust::transform_if( + thrust::device, + lengths.begin(), + lengths.end(), + null_mask.begin(), + lengths.begin(), + [] __device__(auto) { return 0; }, + thrust::logical_not{}); + auto valid_lengths = thrust::make_transform_iterator( + thrust::make_zip_iterator(thrust::make_tuple(lengths.begin(), null_mask.begin())), + valid_or_zero{}); + rmm::device_uvector offsets(num_rows + 1, rmm::cuda_stream_default); + thrust::exclusive_scan( + thrust::device, valid_lengths, valid_lengths + lengths.size(), offsets.begin()); + // offfsets are ready. + auto chars_length = *thrust::device_pointer_cast(offsets.end() - 1); + rmm::device_uvector chars(chars_length, rmm::cuda_stream_default); + thrust::for_each_n(thrust::device, + thrust::make_zip_iterator(offsets.begin(), offsets.begin() + 1), + num_rows, + string_generator{chars.data(), engine}); + auto [result_bitmask, null_count] = + cudf::detail::valid_if(null_mask.begin(), null_mask.end() - 1, thrust::identity{}); + return cudf::make_strings_column( + num_rows, + std::move(offsets), + std::move(chars), + profile.get_null_frequency().has_value() ? std::move(result_bitmask) : rmm::device_buffer{}); } /** @@ -409,53 +498,26 @@ void append_string(Char_gen& char_gen, bool valid, uint32_t length, string_colum */ template <> std::unique_ptr create_random_column(data_profile const& profile, - std::mt19937& engine, + thrust::minstd_rand& engine, cudf::size_type num_rows) { - auto char_dist = [&engine, // range 32-127 is ASCII; 127-136 will be multi-byte UTF-8 - dist = std::uniform_int_distribution{32, 137}]() mutable { - return dist(engine); - }; - auto len_dist = - random_value_fn{profile.get_distribution_params().length_params}; - auto valid_dist = std::bernoulli_distribution{1. 
- profile.get_null_frequency()}; - - auto const avg_string_len = non_fixed_width_size(profile); - auto const cardinality = std::min(profile.get_cardinality(), num_rows); - string_column_data samples(cardinality, cardinality * avg_string_len); - for (cudf::size_type si = 0; si < cardinality; ++si) { - append_string(char_dist, valid_dist(engine), len_dist(engine), samples); - } - + auto const cardinality = std::min(profile.get_cardinality(), num_rows); auto const avg_run_len = profile.get_avg_run_length(); - auto run_len_dist = create_run_length_dist(avg_run_len); - - string_column_data out_col(num_rows, num_rows * avg_string_len); - std::uniform_int_distribution sample_dist{0, cardinality - 1}; - for (cudf::size_type row = 0; row < num_rows; ++row) { - if (cardinality == 0) { - append_string(char_dist, valid_dist(engine), len_dist(engine), out_col); - } else { - copy_string(sample_dist(engine), samples, row, out_col); - } - if (avg_run_len > 1) { - int const run_len = std::min(num_rows - row, std::round(run_len_dist(engine))); - for (int offset = 1; offset < run_len; ++offset) { - copy_string(row, out_col, row + offset, out_col); - } - row += std::max(run_len - 1, 0); - } - } - auto d_chars = cudf::detail::make_device_uvector_sync(out_col.chars); - auto d_offsets = cudf::detail::make_device_uvector_sync(out_col.offsets); - auto d_null_mask = cudf::detail::make_device_uvector_sync(out_col.null_mask); - return cudf::make_strings_column(d_chars, d_offsets, d_null_mask); + auto sample_strings = + create_random_utf8_string_column(profile, engine, cardinality == 0 ? 
num_rows : cardinality); + if (cardinality == 0) { return sample_strings; } + auto sample_indices = sample_indices_with_run_length(avg_run_len, cardinality, num_rows, engine); + auto str_table = cudf::detail::gather(cudf::table_view{{sample_strings->view()}}, + sample_indices, + cudf::out_of_bounds_policy::DONT_CHECK, + cudf::detail::negative_index_policy::NOT_ALLOWED); + return std::move(str_table->release()[0]); } template <> std::unique_ptr create_random_column(data_profile const& profile, - std::mt19937& engine, + thrust::minstd_rand& engine, cudf::size_type num_rows) { CUDF_FAIL("not implemented yet"); @@ -463,7 +525,7 @@ std::unique_ptr create_random_column(data_prof template <> std::unique_ptr create_random_column(data_profile const& profile, - std::mt19937& engine, + thrust::minstd_rand& engine, cudf::size_type num_rows) { CUDF_FAIL("not implemented yet"); @@ -476,13 +538,19 @@ struct create_rand_col_fn { public: template std::unique_ptr operator()(data_profile const& profile, - std::mt19937& engine, + thrust::minstd_rand& engine, cudf::size_type num_rows) { return create_random_column(profile, engine, num_rows); } }; +template +struct clamp_down : public thrust::unary_function { + T max; + clamp_down(T max) : max(max) {} + __host__ __device__ T operator()(T x) const { return min(x, max); } +}; /** * @brief Creates a list column with random content. * @@ -497,7 +565,7 @@ struct create_rand_col_fn { */ template <> std::unique_ptr create_random_column(data_profile const& profile, - std::mt19937& engine, + thrust::minstd_rand& engine, cudf::size_type num_rows) { auto const dist_params = profile.get_distribution_params(); @@ -508,7 +576,8 @@ std::unique_ptr create_random_column(data_profile cudf::data_type(dist_params.element_type), create_rand_col_fn{}, profile, engine, num_elements); auto len_dist = random_value_fn{profile.get_distribution_params().length_params}; - auto valid_dist = std::bernoulli_distribution{1. 
- profile.get_null_frequency()}; + auto valid_dist = + random_value_fn(distribution_params{1. - profile.get_null_frequency().value_or(0)}); // Generate the list column bottom-up auto list_column = std::move(leaf_column); @@ -517,29 +586,27 @@ std::unique_ptr create_random_column(data_profile auto current_child_column = std::move(list_column); cudf::size_type const num_rows = current_child_column->size() / single_level_mean; - std::vector offsets{0}; - offsets.reserve(num_rows + 1); - std::vector null_mask(null_mask_size(num_rows), ~0); - for (cudf::size_type row = 1; row < num_rows + 1; ++row) { - offsets.push_back( - std::min(current_child_column->size(), offsets.back() + len_dist(engine))); - if (!valid_dist(engine)) cudf::clear_bit_unsafe(null_mask.data(), row); - } - offsets.back() = current_child_column->size(); // Always include all elements + auto offsets = len_dist(engine, num_rows + 1); + auto valids = valid_dist(engine, num_rows); + // to ensure these values <= current_child_column->size() + auto output_offsets = thrust::make_transform_output_iterator( + offsets.begin(), clamp_down{current_child_column->size()}); + + thrust::exclusive_scan(thrust::device, offsets.begin(), offsets.end(), output_offsets); + thrust::device_pointer_cast(offsets.end())[-1] = + current_child_column->size(); // Always include all elements auto offsets_column = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, - offsets.size(), - rmm::device_buffer( - offsets.data(), offsets.size() * sizeof(int32_t), rmm::cuda_stream_default)); + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release()); + auto [null_mask, null_count] = + cudf::detail::valid_if(valids.begin(), valids.end(), thrust::identity{}); list_column = cudf::make_lists_column( num_rows, std::move(offsets_column), std::move(current_child_column), - cudf::UNKNOWN_NULL_COUNT, - rmm::device_buffer( - null_mask.data(), null_mask.size() * sizeof(cudf::bitmask_type), rmm::cuda_stream_default)); + 
profile.get_null_frequency().has_value() ? null_count : 0, // cudf::UNKNOWN_NULL_COUNT, + profile.get_null_frequency().has_value() ? std::move(null_mask) : rmm::device_buffer{}); } return list_column; // return the top-level column } @@ -558,12 +625,13 @@ using columns_vector = std::vector>; */ columns_vector create_random_columns(data_profile const& profile, std::vector dtype_ids, - std::mt19937 engine, + thrust::minstd_rand engine, cudf::size_type num_rows) { columns_vector output_columns; std::transform( dtype_ids.begin(), dtype_ids.end(), std::back_inserter(output_columns), [&](auto tid) { + engine.discard(num_rows); return cudf::type_dispatcher( cudf::data_type(tid), create_rand_col_fn{}, profile, engine, num_rows); }); @@ -604,57 +672,56 @@ std::unique_ptr create_random_table(std::vector cons data_profile const& profile, unsigned seed) { - cudf::size_type const num_cols = dtype_ids.size(); - auto seed_engine = deterministic_engine(seed); - - auto const processor_count = std::thread::hardware_concurrency(); - cudf::size_type const cols_per_thread = (num_cols + processor_count - 1) / processor_count; - cudf::size_type next_col = 0; - std::vector> col_futures; - random_value_fn seed_dist( - {distribution_id::UNIFORM, 0, std::numeric_limits::max()}); - for (unsigned int i = 0; i < processor_count && next_col < num_cols; ++i) { - auto thread_engine = deterministic_engine(seed_dist(seed_engine)); - auto const thread_num_cols = std::min(num_cols - next_col, cols_per_thread); - std::vector thread_types(dtype_ids.begin() + next_col, - dtype_ids.begin() + next_col + thread_num_cols); - col_futures.emplace_back(std::async(std::launch::async, - create_random_columns, - std::cref(profile), - std::move(thread_types), - std::move(thread_engine), - num_rows.count)); - next_col += thread_num_cols; - } + auto seed_engine = deterministic_engine(seed); + thrust::uniform_int_distribution seed_dist; columns_vector output_columns; - for (auto& cf : col_futures) { - auto 
partial_table = cf.get(); - output_columns.reserve(output_columns.size() + partial_table.size()); - std::move( - std::begin(partial_table), std::end(partial_table), std::back_inserter(output_columns)); - partial_table.clear(); - } - + std::transform( + dtype_ids.begin(), dtype_ids.end(), std::back_inserter(output_columns), [&](auto tid) mutable { + auto engine = deterministic_engine(seed_dist(seed_engine)); + return cudf::type_dispatcher( + cudf::data_type(tid), create_rand_col_fn{}, profile, engine, num_rows.count); + }); return std::make_unique(std::move(output_columns)); } std::unique_ptr create_sequence_table(std::vector const& dtype_ids, row_count num_rows, - float null_probability, + std::optional null_probability, unsigned seed) { + auto seed_engine = deterministic_engine(seed); + thrust::uniform_int_distribution seed_dist; + auto columns = std::vector>(dtype_ids.size()); std::transform(dtype_ids.begin(), dtype_ids.end(), columns.begin(), [&](auto dtype) mutable { - auto init = cudf::make_default_constructed_scalar(cudf::data_type{dtype}); - auto col = cudf::sequence(num_rows.count, *init); - auto [mask, count] = create_random_null_mask(num_rows.count, null_probability, seed++); + auto init = cudf::make_default_constructed_scalar(cudf::data_type{dtype}); + auto col = cudf::sequence(num_rows.count, *init); + auto [mask, count] = + create_random_null_mask(num_rows.count, null_probability, seed_dist(seed_engine)); col->set_null_mask(std::move(mask), count); return col; }); return std::make_unique(std::move(columns)); } +std::pair create_random_null_mask( + cudf::size_type size, std::optional null_probability, unsigned seed) +{ + if (not null_probability.has_value()) { return {rmm::device_buffer{}, 0}; } + CUDF_EXPECTS(*null_probability >= 0.0 and *null_probability <= 1.0, + "Null probability must be within the range [0.0, 1.0]"); + if (*null_probability == 0.0f) { + return {cudf::create_null_mask(size, cudf::mask_state::ALL_VALID), 0}; + } else if 
(*null_probability == 1.0) { + return {cudf::create_null_mask(size, cudf::mask_state::ALL_NULL), size}; + } else { + return cudf::detail::valid_if(thrust::make_counting_iterator(0), + thrust::make_counting_iterator(size), + bool_generator{seed, 1.0 - *null_probability}); + } +} + std::vector get_type_or_group(int32_t id) { // identity transformation when passing a concrete type_id diff --git a/cpp/benchmarks/common/generate_input.hpp b/cpp/benchmarks/common/generate_input.hpp index 5246de00a73..c955f60f97e 100644 --- a/cpp/benchmarks/common/generate_input.hpp +++ b/cpp/benchmarks/common/generate_input.hpp @@ -19,7 +19,6 @@ #include #include -#include #include /** @@ -217,10 +216,10 @@ class data_profile { cudf::type_id::INT32, {distribution_id::GEOMETRIC, 0, 100}, 2}; std::map> decimal_params; - double bool_probability = 0.5; - double null_frequency = 0.01; - cudf::size_type cardinality = 2000; - cudf::size_type avg_run_length = 4; + double bool_probability = 0.5; + std::optional null_frequency = 0.01; + cudf::size_type cardinality = 2000; + cudf::size_type avg_run_length = 4; public: template (), T>* = nullptr> + void set_distribution_params(Type_enum type_or_group, + distribution_id dist, + typename T::rep lower_bound, + typename T::rep upper_bound) + { + for (auto tid : get_type_or_group(static_cast(type_or_group))) { + int_params[tid] = { + dist, static_cast(lower_bound), static_cast(upper_bound)}; + } + } + void set_bool_probability(double p) { bool_probability = p; } - void set_null_frequency(double f) { null_frequency = f; } + void set_null_frequency(std::optional f) { null_frequency = f; } void set_cardinality(cudf::size_type c) { cardinality = c; } void set_avg_run_length(cudf::size_type avg_rl) { avg_run_length = avg_rl; } @@ -399,14 +410,15 @@ std::unique_ptr create_random_table(std::vector cons * @param dtype_ids Vector of requested column types * @param num_rows Number of rows in the output table * @param null_probability optional, probability of a 
null value - * <0 implies no null mask, =0 implies all valids, >=1 implies all nulls + * no value implies no null mask, =0 implies all valids, >=1 implies all nulls * @param seed optional, seed for the pseudo-random engine * @return A table with the sequence columns. */ -std::unique_ptr create_sequence_table(std::vector const& dtype_ids, - row_count num_rows, - float null_probability = -1.0, - unsigned seed = 1); +std::unique_ptr create_sequence_table( + std::vector const& dtype_ids, + row_count num_rows, + std::optional null_probability = std::nullopt, + unsigned seed = 1); /** * @brief Repeats the input data types cyclically to fill a vector of @ref num_cols @@ -423,10 +435,9 @@ std::vector cycle_dtypes(std::vector const& dtype_ * * @param size number of rows * @param null_probability probability of a null value - * <0 implies no null mask, =0 implies all valids, >=1 implies all nulls + * no value implies no null mask, =0 implies all valids, >=1 implies all nulls * @param seed optional, seed for the pseudo-random engine * @return null mask device buffer with random null mask data and null count */ -std::pair create_random_null_mask(cudf::size_type size, - float null_probability, - unsigned seed = 1); +std::pair create_random_null_mask( + cudf::size_type size, std::optional null_probability = std::nullopt, unsigned seed = 1); diff --git a/cpp/benchmarks/common/generate_nullmask.cu b/cpp/benchmarks/common/generate_nullmask.cu deleted file mode 100644 index 502af95a971..00000000000 --- a/cpp/benchmarks/common/generate_nullmask.cu +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "generate_input.hpp" - -#include -#include - -#include - -/** - * @brief bool generator with given probability [0.0 - 1.0] of returning true. - * - */ -struct bool_generator { - thrust::minstd_rand engine; - thrust::uniform_real_distribution dist; - float probability_true; - bool_generator(unsigned seed, float probability_true) - : engine(seed), dist{0, 1}, probability_true{probability_true} - { - } - - __device__ bool operator()(size_t n) - { - engine.discard(n); - return dist(engine) < probability_true; - } -}; - -std::pair create_random_null_mask(cudf::size_type size, - float null_probability, - unsigned seed) -{ - if (null_probability < 0.0f) { - return {rmm::device_buffer{}, 0}; - } else if (null_probability == 0.0f) { - return {cudf::create_null_mask(size, cudf::mask_state::ALL_NULL), size}; - } else if (null_probability >= 1.0f) { - return {cudf::create_null_mask(size, cudf::mask_state::ALL_VALID), 0}; - } else { - return cudf::detail::valid_if(thrust::make_counting_iterator(0), - thrust::make_counting_iterator(size), - bool_generator{seed, 1.0f - null_probability}); - } -}; diff --git a/cpp/benchmarks/common/random_distribution_factory.cuh b/cpp/benchmarks/common/random_distribution_factory.cuh new file mode 100644 index 00000000000..0f508e9685b --- /dev/null +++ b/cpp/benchmarks/common/random_distribution_factory.cuh @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "generate_input.hpp" + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +/** + * @brief Real Type that has atleast number of bits of integral type in its mantissa. + * number of bits of integrals < 23 bits of mantissa in float + * to allow full range of integer bits to be generated. + * @tparam T integral type + */ +template +using integral_to_realType = + std::conditional_t, + T, + std::conditional_t>; + +/** + * @brief Generates a normal distribution between zero and upper_bound. + */ +template +auto make_normal_dist(T lower_bound, T upper_bound) +{ + using realT = integral_to_realType; + T const mean = lower_bound + (upper_bound - lower_bound) / 2; + T const stddev = (upper_bound - lower_bound) / 6; + return thrust::random::normal_distribution(mean, stddev); +} + +template , T>* = nullptr> +auto make_uniform_dist(T range_start, T range_end) +{ + return thrust::uniform_int_distribution(range_start, range_end); +} + +template ()>* = nullptr> +auto make_uniform_dist(T range_start, T range_end) +{ + return thrust::uniform_real_distribution(range_start, range_end); +} + +template +double geometric_dist_p(T range_size) +{ + constexpr double percentage_in_range = 0.99; + double const p = 1 - exp(log(1 - percentage_in_range) / range_size); + return p ? 
p : std::numeric_limits::epsilon(); +} + +/** + * @brief Generates a geometric distribution between lower_bound and upper_bound. + * This distribution is an approximation generated using normal distribution. + * + * @tparam T Result type of the number to produce. + */ +template +class geometric_distribution : public thrust::random::normal_distribution> { + using realType = integral_to_realType; + using super_t = thrust::random::normal_distribution; + T _lower_bound; + T _upper_bound; + + public: + using result_type = T; + __host__ __device__ explicit geometric_distribution(T lower_bound, T upper_bound) + : super_t(0, std::labs(upper_bound - lower_bound) / 4.0), + _lower_bound(lower_bound), + _upper_bound(upper_bound) + { + } + + template + __host__ __device__ result_type operator()(UniformRandomNumberGenerator& urng) + { + return _lower_bound < _upper_bound ? std::abs(super_t::operator()(urng)) + _lower_bound + : _lower_bound - std::abs(super_t::operator()(urng)); + } +}; + +template +struct value_generator { + using result_type = T; + + value_generator(T lower_bound, T upper_bound, thrust::minstd_rand& engine, Generator gen) + : lower_bound(std::min(lower_bound, upper_bound)), + upper_bound(std::max(lower_bound, upper_bound)), + engine(engine), + dist(gen) + { + } + + __device__ T operator()(size_t n) + { + engine.discard(n); + if constexpr (cuda::std::is_integral_v && + cuda::std::is_floating_point_v) { + return std::clamp(static_cast(std::round(dist(engine))), lower_bound, upper_bound); + } else { + return std::clamp(dist(engine), lower_bound, upper_bound); + } + // Note: uniform does not need clamp, because already range is guaranteed to be within bounds. 
+ } + + T lower_bound; + T upper_bound; + thrust::minstd_rand engine; + Generator dist; +}; + +template +using distribution_fn = std::function(thrust::minstd_rand&, size_t)>; + +template < + typename T, + std::enable_if_t or cuda::std::is_floating_point_v, T>* = nullptr> +distribution_fn make_distribution(distribution_id dist_id, T lower_bound, T upper_bound) +{ + switch (dist_id) { + case distribution_id::NORMAL: + return [lower_bound, upper_bound, dist = make_normal_dist(lower_bound, upper_bound)]( + thrust::minstd_rand& engine, size_t size) -> rmm::device_uvector { + rmm::device_uvector result(size, rmm::cuda_stream_default); + thrust::tabulate(thrust::device, + result.begin(), + result.end(), + value_generator{lower_bound, upper_bound, engine, dist}); + return result; + }; + case distribution_id::UNIFORM: + return [lower_bound, upper_bound, dist = make_uniform_dist(lower_bound, upper_bound)]( + thrust::minstd_rand& engine, size_t size) -> rmm::device_uvector { + rmm::device_uvector result(size, rmm::cuda_stream_default); + thrust::tabulate(thrust::device, + result.begin(), + result.end(), + value_generator{lower_bound, upper_bound, engine, dist}); + return result; + }; + case distribution_id::GEOMETRIC: + // kind of exponential distribution from lower_bound to upper_bound. 
+ return [lower_bound, upper_bound, dist = geometric_distribution(lower_bound, upper_bound)]( + thrust::minstd_rand& engine, size_t size) -> rmm::device_uvector { + rmm::device_uvector result(size, rmm::cuda_stream_default); + thrust::tabulate(thrust::device, + result.begin(), + result.end(), + value_generator{lower_bound, upper_bound, engine, dist}); + return result; + }; + default: CUDF_FAIL("Unsupported probability distribution"); + } +} diff --git a/cpp/benchmarks/common/random_distribution_factory.hpp b/cpp/benchmarks/common/random_distribution_factory.hpp deleted file mode 100644 index f2f3833f15d..00000000000 --- a/cpp/benchmarks/common/random_distribution_factory.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "generate_input.hpp" - -#include -#include - -/** - * @brief Generates a normal(binomial) distribution between zero and upper_bound. - */ -template , T>* = nullptr> -auto make_normal_dist(T upper_bound) -{ - using uT = typename std::make_unsigned::type; - return std::binomial_distribution(upper_bound, 0.5); -} - -/** - * @brief Generates a normal distribution between zero and upper_bound. 
- */ -template ()>* = nullptr> -auto make_normal_dist(T upper_bound) -{ - T const mean = upper_bound / 2; - T const stddev = upper_bound / 6; - return std::normal_distribution(mean, stddev); -} - -template , T>* = nullptr> -auto make_uniform_dist(T range_start, T range_end) -{ - return std::uniform_int_distribution(range_start, range_end); -} - -template ()>* = nullptr> -auto make_uniform_dist(T range_start, T range_end) -{ - return std::uniform_real_distribution(range_start, range_end); -} - -template -double geometric_dist_p(T range_size) -{ - constexpr double percentage_in_range = 0.99; - double const p = 1 - exp(log(1 - percentage_in_range) / range_size); - return p ? p : std::numeric_limits::epsilon(); -} - -template , T>* = nullptr> -auto make_geometric_dist(T range_start, T range_end) -{ - using uT = typename std::make_unsigned::type; - if (range_start > range_end) std::swap(range_start, range_end); - - uT const range_size = (uT)range_end - (uT)range_start; - return std::geometric_distribution(geometric_dist_p(range_size)); -} - -template ()>* = nullptr> -auto make_geometric_dist(T range_start, T range_end) -{ - long double const range_size = range_end - range_start; - return std::exponential_distribution(geometric_dist_p(range_size)); -} - -template -using distribution_fn = std::function; - -template , T>* = nullptr> -distribution_fn make_distribution(distribution_id did, T lower_bound, T upper_bound) -{ - switch (did) { - case distribution_id::NORMAL: - return [lower_bound, dist = make_normal_dist(upper_bound - lower_bound)]( - std::mt19937& engine) mutable -> T { return dist(engine) + lower_bound; }; - case distribution_id::UNIFORM: - return [dist = make_uniform_dist(lower_bound, upper_bound)]( - std::mt19937& engine) mutable -> T { return dist(engine); }; - case distribution_id::GEOMETRIC: - return [lower_bound, upper_bound, dist = make_geometric_dist(lower_bound, upper_bound)]( - std::mt19937& engine) mutable -> T { - if (lower_bound <= upper_bound) - 
return dist(engine); - else - return lower_bound - dist(engine) + lower_bound; - }; - default: CUDF_FAIL("Unsupported probability distribution"); - } -} - -template ()>* = nullptr> -distribution_fn make_distribution(distribution_id dist_id, T lower_bound, T upper_bound) -{ - switch (dist_id) { - case distribution_id::NORMAL: - return [lower_bound, dist = make_normal_dist(upper_bound - lower_bound)]( - std::mt19937& engine) mutable -> T { return dist(engine) + lower_bound; }; - case distribution_id::UNIFORM: - return [dist = make_uniform_dist(lower_bound, upper_bound)]( - std::mt19937& engine) mutable -> T { return dist(engine); }; - case distribution_id::GEOMETRIC: - return [lower_bound, upper_bound, dist = make_geometric_dist(lower_bound, upper_bound)]( - std::mt19937& engine) mutable -> T { - if (lower_bound <= upper_bound) - return lower_bound + dist(engine); - else - return lower_bound - dist(engine); - }; - default: CUDF_FAIL("Unsupported random distribution"); - } -} diff --git a/cpp/benchmarks/copying/contiguous_split.cu b/cpp/benchmarks/copying/contiguous_split.cu index bb6a9320c4a..9f691e903f7 100644 --- a/cpp/benchmarks/copying/contiguous_split.cu +++ b/cpp/benchmarks/copying/contiguous_split.cu @@ -14,17 +14,13 @@ * limitations under the License. 
*/ -#include - -#include - -#include - +#include #include #include #include -// to enable, run cmake with -DBUILD_BENCHMARKS=ON +#include +#include template void BM_contiguous_split_common(benchmark::State& state, @@ -48,15 +44,12 @@ void BM_contiguous_split_common(benchmark::State& state, }); } - std::vector> columns(src_cols.size()); - std::transform(src_cols.begin(), src_cols.end(), columns.begin(), [](T& in) { - auto ret = in.release(); + for (auto const& col : src_cols) // computing the null count is not a part of the benchmark's target code path, and we want the // property to be pre-computed so that we measure the performance of only the intended code path - [[maybe_unused]] auto const nulls = ret->null_count(); - return ret; - }); - auto const src_table = cudf::table(std::move(columns)); + [[maybe_unused]] auto const nulls = col->null_count(); + + auto const src_table = cudf::table(std::move(src_cols)); for (auto _ : state) { cuda_event_timer raii(state, true); // flush_l2_cache = true, stream = 0 @@ -81,20 +74,17 @@ void BM_contiguous_split(benchmark::State& state) int64_t const num_rows = total_desired_bytes / (num_cols * el_size); // generate input table - srand(31337); - auto valids = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return true; }); - std::vector> src_cols(num_cols); - for (int idx = 0; idx < num_cols; idx++) { - auto rand_elements = - cudf::detail::make_counting_transform_iterator(0, [](int i) { return rand(); }); - if (include_validity) { - src_cols[idx] = cudf::test::fixed_width_column_wrapper( - rand_elements, rand_elements + num_rows, valids); - } else { - src_cols[idx] = - cudf::test::fixed_width_column_wrapper(rand_elements, rand_elements + num_rows); - } - } + data_profile profile; + if (not include_validity) profile.set_null_frequency(std::nullopt); // <0 means, no null_mask + profile.set_cardinality(0); + auto range = default_range(); + profile.set_distribution_params( + cudf::type_id::INT32, 
distribution_id::UNIFORM, range.first, range.second); + + auto src_cols = create_random_table(cycle_dtypes({cudf::type_id::INT32}, num_cols), + row_count{static_cast(num_rows)}, + profile) + ->release(); int64_t const total_bytes = total_desired_bytes + @@ -107,12 +97,6 @@ void BM_contiguous_split(benchmark::State& state) class ContiguousSplitStrings : public cudf::benchmark { }; -int rand_range(int r) -{ - return static_cast((static_cast(rand()) / static_cast(RAND_MAX)) * - (float)(r - 1)); -} - void BM_contiguous_split_strings(benchmark::State& state) { int64_t const total_desired_bytes = state.range(0); @@ -128,22 +112,24 @@ void BM_contiguous_split_strings(benchmark::State& state) int64_t const num_rows = col_len_bytes / string_len; // generate input table - srand(31337); - auto valids = cudf::detail::make_counting_transform_iterator( - 0, [](auto i) { return i % 2 == 0 ? true : false; }); - std::vector src_cols; - std::vector one_col(num_rows); + data_profile profile; + profile.set_null_frequency(std::nullopt); // <0 means, no null mask + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_id::INT32, + distribution_id::UNIFORM, + 0, + include_validity ? 
h_strings.size() * 2 : h_strings.size() - 1); // out of bounds nullified + cudf::test::strings_column_wrapper one_col(h_strings.begin(), h_strings.end()); + std::vector> src_cols(num_cols); for (int64_t idx = 0; idx < num_cols; idx++) { - // fill in a random set of strings - for (int64_t s_idx = 0; s_idx < num_rows; s_idx++) { - one_col[s_idx] = h_strings[rand_range(h_strings.size())]; - } - if (include_validity) { - src_cols.push_back( - cudf::test::strings_column_wrapper(one_col.begin(), one_col.end(), valids)); - } else { - src_cols.push_back(cudf::test::strings_column_wrapper(one_col.begin(), one_col.end())); - } + auto random_indices = create_random_table( + {cudf::type_id::INT32}, row_count{static_cast(num_rows)}, profile); + auto str_table = cudf::gather(cudf::table_view{{one_col}}, + random_indices->get_column(0), + (include_validity ? cudf::out_of_bounds_policy::NULLIFY + : cudf::out_of_bounds_policy::DONT_CHECK)); + src_cols[idx] = std::move(str_table->release()[0]); } int64_t const total_bytes = diff --git a/cpp/benchmarks/copying/gather.cu b/cpp/benchmarks/copying/gather.cu index eaa201a0678..1dd4cefb338 100644 --- a/cpp/benchmarks/copying/gather.cu +++ b/cpp/benchmarks/copying/gather.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,23 +14,15 @@ * limitations under the License. 
*/ -#include +#include +#include +#include #include - -#include -#include -#include -#include -#include - #include -#include -#include - -#include "../fixture/benchmark_fixture.hpp" -#include "../synchronization/synchronization.hpp" +#include +#include class Gather : public cudf::benchmark { }; @@ -41,38 +33,28 @@ void BM_gather(benchmark::State& state) const cudf::size_type source_size{(cudf::size_type)state.range(0)}; const auto n_cols = (cudf::size_type)state.range(1); - // Every element is valid - auto data = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i; }); - // Gather indices - std::vector host_map_data(source_size); - std::iota(host_map_data.begin(), host_map_data.end(), 0); + auto gather_map_table = + create_sequence_table({cudf::type_to_id()}, row_count{source_size}); + auto gather_map = gather_map_table->get_column(0).mutable_view(); if (coalesce) { - std::reverse(host_map_data.begin(), host_map_data.end()); + thrust::reverse( + thrust::device, gather_map.begin(), gather_map.end()); } else { - std::random_shuffle(host_map_data.begin(), host_map_data.end()); + thrust::shuffle(thrust::device, + gather_map.begin(), + gather_map.end(), + thrust::default_random_engine()); } - cudf::test::fixed_width_column_wrapper gather_map(host_map_data.begin(), - host_map_data.end()); - - std::vector> source_column_wrappers; - std::vector source_columns(n_cols); - - std::generate_n(std::back_inserter(source_column_wrappers), n_cols, [=]() { - return cudf::test::fixed_width_column_wrapper(data, data + source_size); - }); - std::transform(source_column_wrappers.begin(), - source_column_wrappers.end(), - source_columns.begin(), - [](auto const& col) { return static_cast(col); }); - - cudf::table_view source_table{source_columns}; + // Every element is valid + auto source_table = create_sequence_table(cycle_dtypes({cudf::type_to_id()}, n_cols), + row_count{source_size}); for (auto _ : state) { cuda_event_timer raii(state, true); // flush_l2_cache = 
true, stream = 0 - cudf::gather(source_table, gather_map); + cudf::gather(*source_table, gather_map); } state.SetBytesProcessed(state.iterations() * state.range(0) * n_cols * 2 * sizeof(TypeParam)); diff --git a/cpp/benchmarks/copying/shift.cu b/cpp/benchmarks/copying/shift.cu index 42d8b58aca3..87718029cb2 100644 --- a/cpp/benchmarks/copying/shift.cu +++ b/cpp/benchmarks/copying/shift.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,22 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include #include -#include -#include - -#include - -#include -#include -#include -#include - -#include template > std::unique_ptr make_scalar( @@ -64,15 +54,12 @@ static void BM_shift(benchmark::State& state) { cudf::size_type size = state.range(0); cudf::size_type offset = size * (static_cast(shift_factor) / 100.0); - auto idx_begin = thrust::make_counting_iterator(0); - auto idx_end = thrust::make_counting_iterator(size); - - auto input = use_validity - ? cudf::test::fixed_width_column_wrapper( - idx_begin, - idx_end, - thrust::make_transform_iterator(idx_begin, [](auto idx) { return true; })) - : cudf::test::fixed_width_column_wrapper(idx_begin, idx_end); + + auto const input_table = + create_sequence_table({cudf::type_to_id()}, + row_count{size}, + use_validity ? std::optional{1.0} : std::nullopt); + cudf::column_view input{input_table->get_column(0)}; auto fill = use_validity ? make_scalar() : make_scalar(777); diff --git a/cpp/benchmarks/filling/repeat.cpp b/cpp/benchmarks/filling/repeat.cpp index 3cedd55767d..a73513e80af 100644 --- a/cpp/benchmarks/filling/repeat.cpp +++ b/cpp/benchmarks/filling/repeat.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. 
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,11 @@ * limitations under the License. */ -#include +#include +#include +#include #include -#include - -#include -#include - -#include - -#include - -#include "../fixture/benchmark_fixture.hpp" -#include "../synchronization/synchronization.hpp" class Repeat : public cudf::benchmark { }; @@ -35,39 +26,24 @@ class Repeat : public cudf::benchmark { template void BM_repeat(benchmark::State& state) { - using column_wrapper = cudf::test::fixed_width_column_wrapper; - auto const n_rows = static_cast(state.range(0)); - auto const n_cols = static_cast(state.range(1)); - - auto idx_begin = thrust::make_counting_iterator(0); - auto idx_end = thrust::make_counting_iterator(n_rows); + auto const n_rows = static_cast(state.range(0)); + auto const n_cols = static_cast(state.range(1)); - std::vector columns; - columns.reserve(n_rows); - std::generate_n(std::back_inserter(columns), n_cols, [&]() { - return nulls ? column_wrapper( - idx_begin, - idx_end, - thrust::make_transform_iterator(idx_begin, [](auto idx) { return true; })) - : column_wrapper(idx_begin, idx_end); - }); + auto const input_table = + create_sequence_table(cycle_dtypes({cudf::type_to_id()}, n_cols), + row_count{n_rows}, + nulls ? 
std::optional{1.0} : std::nullopt); + // Create table view + auto input = cudf::table_view(*input_table); // repeat counts - std::default_random_engine generator; - std::uniform_int_distribution distribution(0, 3); - - std::vector host_repeat_count(n_rows); - std::generate( - host_repeat_count.begin(), host_repeat_count.end(), [&] { return distribution(generator); }); - - cudf::test::fixed_width_column_wrapper repeat_count(host_repeat_count.begin(), - host_repeat_count.end()); - - // Create column views - auto const column_views = std::vector(columns.begin(), columns.end()); - - // Create table view - auto input = cudf::table_view(column_views); + using sizeT = cudf::size_type; + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 3); + auto repeat_table = create_random_table({cudf::type_to_id()}, row_count{n_rows}, profile); + cudf::column_view repeat_count{repeat_table->get_column(0)}; // warm up auto output = cudf::repeat(input, repeat_count); diff --git a/cpp/benchmarks/groupby/group_no_requests.cu b/cpp/benchmarks/groupby/group_no_requests.cu index 750e0c6d3b3..4639a1b8982 100644 --- a/cpp/benchmarks/groupby/group_no_requests.cu +++ b/cpp/benchmarks/groupby/group_no_requests.cu @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include #include #include #include @@ -22,32 +23,27 @@ #include #include #include -#include - -#include - -#include class Groupby : public cudf::benchmark { }; void BM_basic_no_requests(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); - - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); std::vector requests; for (auto _ : state) { cuda_event_timer timer(state, true); - cudf::groupby::groupby gb_obj(cudf::table_view({keys})); + cudf::groupby::groupby gb_obj(*keys_table); auto result = gb_obj.aggregate(requests); } } @@ -67,21 +63,18 @@ BENCHMARK_REGISTER_F(Groupby, BasicNoRequest) void BM_pre_sorted_no_requests(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); - auto valid_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100) < 90; }); - - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size, valid_it); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, 
profile); - auto keys_table = cudf::table_view({keys}); - auto sort_order = cudf::sorted_order(keys_table); - auto sorted_keys = cudf::gather(keys_table, *sort_order); + auto sort_order = cudf::sorted_order(*keys_table); + auto sorted_keys = cudf::gather(*keys_table, *sort_order); // No need to sort values using sort_order because they were generated randomly std::vector requests; diff --git a/cpp/benchmarks/groupby/group_nth.cu b/cpp/benchmarks/groupby/group_nth.cu index daeb88f6dee..f574dd4f64a 100644 --- a/cpp/benchmarks/groupby/group_nth.cu +++ b/cpp/benchmarks/groupby/group_nth.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -22,31 +23,28 @@ #include #include #include -#include - -#include - -#include class Groupby : public cudf::benchmark { }; void BM_pre_sorted_nth(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - // const cudf::size_type num_columns{(cudf::size_type)state.range(0)}; const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); - - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + auto vals_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); - auto keys_table = cudf::table_view({keys}); - auto sort_order = cudf::sorted_order(keys_table); - auto sorted_keys = cudf::gather(keys_table, *sort_order); + cudf::column_view vals(vals_table->get_column(0)); + auto sort_order = cudf::sorted_order(*keys_table); + auto sorted_keys = cudf::gather(*keys_table, *sort_order); // 
No need to sort values using sort_order because they were generated randomly cudf::groupby::groupby gb_obj(*sorted_keys, cudf::null_policy::EXCLUDE, cudf::sorted::YES); diff --git a/cpp/benchmarks/groupby/group_scan.cu b/cpp/benchmarks/groupby/group_scan.cu index 9a6d7b51429..7ccf082a3ba 100644 --- a/cpp/benchmarks/groupby/group_scan.cu +++ b/cpp/benchmarks/groupby/group_scan.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -24,22 +25,25 @@ #include #include -#include - class Groupby : public cudf::benchmark { }; void BM_basic_sum_scan(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + auto vals_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size); + cudf::column_view keys(keys_table->get_column(0)); + cudf::column_view vals(vals_table->get_column(0)); cudf::groupby::groupby gb_obj(cudf::table_view({keys, keys, keys})); @@ -66,21 +70,23 @@ BENCHMARK_REGISTER_F(Groupby, BasicSumScan) void BM_pre_sorted_sum_scan(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); - auto valid_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type 
row) { return random_int(0, 100) < 90; }); - - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size, valid_it); - - auto keys_table = cudf::table_view({keys}); - auto sort_order = cudf::sorted_order(keys_table); - auto sorted_keys = cudf::gather(keys_table, *sort_order); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + profile.set_null_frequency(0.1); + auto vals_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + + cudf::column_view vals(vals_table->get_column(0)); + + auto sort_order = cudf::sorted_order(*keys_table); + auto sorted_keys = cudf::gather(*keys_table, *sort_order); // No need to sort values using sort_order because they were generated randomly cudf::groupby::groupby gb_obj(*sorted_keys, cudf::null_policy::EXCLUDE, cudf::sorted::YES); diff --git a/cpp/benchmarks/groupby/group_shift.cu b/cpp/benchmarks/groupby/group_shift.cu index 29bc99f6b61..d9617deb269 100644 --- a/cpp/benchmarks/groupby/group_shift.cu +++ b/cpp/benchmarks/groupby/group_shift.cu @@ -14,35 +14,36 @@ * limitations under the License. 
*/ +#include #include #include #include -#include #include #include #include #include -#include - class Groupby : public cudf::benchmark { }; void BM_group_shift(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; const int num_groups = 100; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [](cudf::size_type row) { return random_int(0, num_groups); }); + data_profile profile; + profile.set_null_frequency(0.01); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, num_groups); - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + auto vals_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); - cudf::groupby::groupby gb_obj(cudf::table_view({keys})); + cudf::groupby::groupby gb_obj(*keys_table); std::vector offsets{ static_cast(column_size / float(num_groups) * 0.5)}; // forward shift half way @@ -53,7 +54,7 @@ void BM_group_shift(benchmark::State& state) for (auto _ : state) { cuda_event_timer timer(state, true); - auto result = gb_obj.shift(cudf::table_view{{vals}}, offsets, {*fill_value}); + auto result = gb_obj.shift(*vals_table, offsets, {*fill_value}); } } diff --git a/cpp/benchmarks/groupby/group_struct.cu b/cpp/benchmarks/groupby/group_struct.cu index 34f2d1adc75..c5eceda2df2 100644 --- a/cpp/benchmarks/groupby/group_struct.cu +++ b/cpp/benchmarks/groupby/group_struct.cu @@ -18,15 +18,10 @@ #include #include -#include - #include #include #include #include -#include - -#include static constexpr cudf::size_type num_struct_members = 8; static constexpr cudf::size_type max_int = 100; diff --git a/cpp/benchmarks/groupby/group_sum.cu b/cpp/benchmarks/groupby/group_sum.cu index 4a33ddeacd4..4dda47a7bc1 100644 
--- a/cpp/benchmarks/groupby/group_sum.cu +++ b/cpp/benchmarks/groupby/group_sum.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -22,26 +23,26 @@ #include #include #include -#include - -#include - -#include class Groupby : public cudf::benchmark { }; void BM_basic_sum(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + auto vals_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size); + cudf::column_view keys(keys_table->get_column(0)); + cudf::column_view vals(vals_table->get_column(0)); cudf::groupby::groupby gb_obj(cudf::table_view({keys, keys, keys})); @@ -69,21 +70,23 @@ BENCHMARK_REGISTER_F(Groupby, Basic) void BM_pre_sorted_sum(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100); }); - auto valid_it = cudf::detail::make_counting_transform_iterator( - 0, [=](cudf::size_type row) { return random_int(0, 100) < 90; }); - - wrapper keys(data_it, data_it + column_size); - wrapper vals(data_it, data_it + column_size, valid_it); - - auto keys_table = cudf::table_view({keys}); - auto sort_order = cudf::sorted_order(keys_table); - auto sorted_keys = 
cudf::gather(keys_table, *sort_order); + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + profile.set_null_frequency(0.1); + auto vals_table = + create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + + cudf::column_view vals(vals_table->get_column(0)); + + auto sort_order = cudf::sorted_order(*keys_table); + auto sorted_keys = cudf::gather(*keys_table, *sort_order); // No need to sort values using sort_order because they were generated randomly cudf::groupby::groupby gb_obj(*sorted_keys, cudf::null_policy::EXCLUDE, cudf::sorted::YES); diff --git a/cpp/benchmarks/hashing/partition.cpp b/cpp/benchmarks/hashing/partition.cpp index 185f19f28e5..a15cc2d0f5b 100644 --- a/cpp/benchmarks/hashing/partition.cpp +++ b/cpp/benchmarks/hashing/partition.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,17 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include -#include +#include #include #include -#include -#include -#include +#include -using cudf::test::fixed_width_column_wrapper; +#include class Hashing : public cudf::benchmark { }; @@ -36,18 +33,9 @@ void BM_hash_partition(benchmark::State& state) auto const num_partitions = state.range(2); // Create owning columns - std::vector> columns(num_cols); - std::generate(columns.begin(), columns.end(), [num_rows]() { - auto iter = thrust::make_counting_iterator(0); - return fixed_width_column_wrapper(iter, iter + num_rows); - }); - - // Create table view into columns - std::vector views(columns.size()); - std::transform(columns.begin(), columns.end(), views.begin(), [](auto const& col) { - return static_cast(col); - }); - auto input = cudf::table_view(views); + auto input_table = create_sequence_table(cycle_dtypes({cudf::type_to_id()}, num_cols), + row_count{static_cast(num_rows)}); + auto input = cudf::table_view(*input_table); auto columns_to_hash = std::vector(num_cols); std::iota(columns_to_hash.begin(), columns_to_hash.end(), 0); diff --git a/cpp/benchmarks/merge/merge.cpp b/cpp/benchmarks/merge/merge.cpp index 1af0fcbb237..88354bcc731 100644 --- a/cpp/benchmarks/merge/merge.cpp +++ b/cpp/benchmarks/merge/merge.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -59,7 +59,7 @@ void BM_merge(benchmark::State& state) for (int i = 0; i < num_tables; ++i) { cudf::size_type const rows = std::round(table_size_dist(rand_gen)); // Ensure size in range [0, avg_rows*2] - auto const clamped_rows = std::max(std::min(rows, avg_rows * 2), 0); + auto const clamped_rows = std::clamp(rows, 0, avg_rows * 2); int32_t prev_key = 0; auto key_sequence = cudf::detail::make_counting_transform_iterator(0, [&](auto row) { diff --git a/cpp/benchmarks/quantiles/quantiles.cpp b/cpp/benchmarks/quantiles/quantiles.cpp index 3ecb436d7fa..cc7dfa08c59 100644 --- a/cpp/benchmarks/quantiles/quantiles.cpp +++ b/cpp/benchmarks/quantiles/quantiles.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,12 @@ * limitations under the License. */ -#include - -#include -#include -#include -#include -#include - -#include #include #include #include +#include + #include class Quantiles : public cudf::benchmark { @@ -34,32 +27,22 @@ class Quantiles : public cudf::benchmark { static void BM_quantiles(benchmark::State& state, bool nulls) { - using Type = int; - using column_wrapper = cudf::test::fixed_width_column_wrapper; - std::default_random_engine generator; - std::uniform_int_distribution distribution(0, 100); + using Type = int; const cudf::size_type n_rows{(cudf::size_type)state.range(0)}; const cudf::size_type n_cols{(cudf::size_type)state.range(1)}; const cudf::size_type n_quantiles{(cudf::size_type)state.range(2)}; // Create columns with values in the range [0,100) - std::vector columns; - columns.reserve(n_cols); - std::generate_n(std::back_inserter(columns), n_cols, [&, n_rows]() { - auto elements = cudf::detail::make_counting_transform_iterator( - 0, [&](auto row) { return distribution(generator); }); - if (!nulls) return column_wrapper(elements, 
elements + n_rows); - auto valids = cudf::detail::make_counting_transform_iterator( - 0, [](auto i) { return i % 100 == 0 ? false : true; }); - return column_wrapper(elements, elements + n_rows, valids); - }); - - // Create column views - auto column_views = std::vector(columns.begin(), columns.end()); + data_profile profile; + profile.set_null_frequency(nulls ? std::optional{0.01} + : std::nullopt); // 1% nulls or no null mask (<0) + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); - // Create table view - auto input = cudf::table_view(column_views); + auto input_table = create_random_table( + cycle_dtypes({cudf::type_to_id()}, n_cols), row_count{n_rows}, profile); + auto input = cudf::table_view(*input_table); std::vector q(n_quantiles); thrust::tabulate( diff --git a/cpp/benchmarks/reduction/anyall.cpp b/cpp/benchmarks/reduction/anyall.cpp index d16292655ce..74304c77f32 100644 --- a/cpp/benchmarks/reduction/anyall.cpp +++ b/cpp/benchmarks/reduction/anyall.cpp @@ -14,15 +14,15 @@ * limitations under the License. */ +#include #include #include + #include #include #include -#include -#include -#include +#include class Reduction : public cudf::benchmark { }; @@ -32,13 +32,15 @@ void BM_reduction_anyall(benchmark::State& state, std::unique_ptr const& agg) { const cudf::size_type column_size{static_cast(state.range(0))}; - - cudf::test::UniformRandomGenerator rand_gen( - (agg->kind == cudf::aggregation::ALL ? 1 : 0), (agg->kind == cudf::aggregation::ANY ? 
0 : 100)); - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [&rand_gen](cudf::size_type row) { return rand_gen.generate(); }); - cudf::test::fixed_width_column_wrapper values( - data_it, data_it + column_size); + auto const dtype = cudf::type_to_id(); + data_profile profile; + if (agg->kind == cudf::aggregation::ANY) + profile.set_distribution_params(dtype, distribution_id::UNIFORM, 0, 0); + else + profile.set_distribution_params(dtype, distribution_id::UNIFORM, 0, 100); + auto const table = create_random_table({dtype}, row_count{column_size}, profile); + table->get_column(0).set_null_mask(rmm::device_buffer{}, 0); + cudf::column_view values(table->view().column(0)); cudf::data_type output_dtype{cudf::type_id::BOOL8}; diff --git a/cpp/benchmarks/reduction/dictionary.cpp b/cpp/benchmarks/reduction/dictionary.cpp index 01fb17e31f0..cdb6e311302 100644 --- a/cpp/benchmarks/reduction/dictionary.cpp +++ b/cpp/benchmarks/reduction/dictionary.cpp @@ -14,15 +14,14 @@ * limitations under the License. */ +#include #include #include -#include + +#include #include #include -#include -#include - -#include +#include class ReductionDictionary : public cudf::benchmark { }; @@ -33,12 +32,17 @@ void BM_reduction_dictionary(benchmark::State& state, { const cudf::size_type column_size{static_cast(state.range(0))}; - cudf::test::UniformRandomGenerator rand_gen( - (agg->kind == cudf::aggregation::ALL ? 1 : 0), (agg->kind == cudf::aggregation::ANY ? 0 : 100)); - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [&rand_gen](cudf::size_type row) { return rand_gen.generate(); }); - cudf::test::dictionary_column_wrapper values( - data_it, data_it + column_size); + // int column and encoded dictionary column + data_profile profile; + profile.set_null_frequency(std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), + distribution_id::UNIFORM, + (agg->kind == cudf::aggregation::ALL ? 
1 : 0), + (agg->kind == cudf::aggregation::ANY ? 0 : 100)); + auto int_table = create_random_table({cudf::type_to_id()}, row_count{column_size}, profile); + auto number_col = cudf::cast(int_table->get_column(0), cudf::data_type{cudf::type_to_id()}); + auto values = cudf::dictionary::encode(*number_col); cudf::data_type output_dtype = [&] { if (agg->kind == cudf::aggregation::ANY || agg->kind == cudf::aggregation::ALL) @@ -49,7 +53,7 @@ void BM_reduction_dictionary(benchmark::State& state, for (auto _ : state) { cuda_event_timer timer(state, true); - auto result = cudf::reduce(values, agg, output_dtype); + auto result = cudf::reduce(*values, agg, output_dtype); } } diff --git a/cpp/benchmarks/reduction/minmax.cpp b/cpp/benchmarks/reduction/minmax.cpp index 3b64202eef5..71a92e3498f 100644 --- a/cpp/benchmarks/reduction/minmax.cpp +++ b/cpp/benchmarks/reduction/minmax.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,13 @@ * limitations under the License. 
*/ +#include +#include +#include + #include #include #include -#include -#include -#include -#include - -#include -#include class Reduction : public cudf::benchmark { }; @@ -32,14 +29,10 @@ template void BM_reduction(benchmark::State& state) { const cudf::size_type column_size{(cudf::size_type)state.range(0)}; - - cudf::test::UniformRandomGenerator rand_gen(0, 100); - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [&rand_gen](cudf::size_type row) { return rand_gen.generate(); }); - cudf::test::fixed_width_column_wrapper values( - data_it, data_it + column_size); - - auto input_column = cudf::column_view(values); + auto const dtype = cudf::type_to_id(); + auto const table = create_random_table({dtype}, row_count{column_size}); + table->get_column(0).set_null_mask(rmm::device_buffer{}, 0); + cudf::column_view input_column(table->view().column(0)); for (auto _ : state) { cuda_event_timer timer(state, true); diff --git a/cpp/benchmarks/reduction/reduce.cpp b/cpp/benchmarks/reduction/reduce.cpp index d7350ace65c..d24c9009ccf 100644 --- a/cpp/benchmarks/reduction/reduce.cpp +++ b/cpp/benchmarks/reduction/reduce.cpp @@ -14,17 +14,16 @@ * limitations under the License. 
*/ +#include +#include +#include + #include #include #include #include -#include -#include -#include -#include #include -#include class Reduction : public cudf::benchmark { }; @@ -33,14 +32,13 @@ template void BM_reduction(benchmark::State& state, std::unique_ptr const& agg) { const cudf::size_type column_size{(cudf::size_type)state.range(0)}; + auto const dtype = cudf::type_to_id(); + data_profile profile; + profile.set_distribution_params(dtype, distribution_id::UNIFORM, 0, 100); + auto const table = create_random_table({dtype}, row_count{column_size}, profile); + table->get_column(0).set_null_mask(rmm::device_buffer{}, 0); + cudf::column_view input_column(table->view().column(0)); - cudf::test::UniformRandomGenerator rand_gen(0, 100); - auto data_it = cudf::detail::make_counting_transform_iterator( - 0, [&rand_gen](cudf::size_type row) { return rand_gen.generate(); }); - cudf::test::fixed_width_column_wrapper values( - data_it, data_it + column_size); - - auto input_column = cudf::column_view(values); cudf::data_type output_dtype = (agg->kind == cudf::aggregation::MEAN || agg->kind == cudf::aggregation::VARIANCE || agg->kind == cudf::aggregation::STD) diff --git a/cpp/benchmarks/search/search.cpp b/cpp/benchmarks/search/search.cpp index 0bccbbaff54..6bc509c8746 100644 --- a/cpp/benchmarks/search/search.cpp +++ b/cpp/benchmarks/search/search.cpp @@ -18,17 +18,12 @@ #include #include -#include -#include - #include #include #include #include #include -#include - class Search : public cudf::benchmark { }; @@ -75,38 +70,28 @@ BENCHMARK_REGISTER_F(Search, Column_Nulls) void BM_table(benchmark::State& state) { - using wrapper = cudf::test::fixed_width_column_wrapper; + using Type = float; auto const num_columns{static_cast(state.range(0))}; auto const column_size{static_cast(state.range(1))}; auto const values_size = column_size; - auto make_table = [&](cudf::size_type col_size) { - cudf::test::UniformRandomGenerator random_gen(0, 100); - auto data_it = 
cudf::detail::make_counting_transform_iterator( - 0, [&](cudf::size_type row) { return random_gen.generate(); }); - auto valid_it = cudf::detail::make_counting_transform_iterator( - 0, [&](cudf::size_type row) { return random_gen.generate() < 90; }); - - std::vector> cols; - for (cudf::size_type i = 0; i < num_columns; i++) { - wrapper temp(data_it, data_it + col_size, valid_it); - cols.emplace_back(temp.release()); - } - - return cudf::table(std::move(cols)); - }; - - auto data_table = make_table(column_size); - auto values_table = make_table(values_size); + data_profile profile; + profile.set_cardinality(0); + profile.set_null_frequency(0.1); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto data_table = create_random_table( + cycle_dtypes({cudf::type_to_id()}, num_columns), row_count{column_size}, profile); + auto values_table = create_random_table( + cycle_dtypes({cudf::type_to_id()}, num_columns), row_count{values_size}, profile); std::vector orders(num_columns, cudf::order::ASCENDING); std::vector null_orders(num_columns, cudf::null_order::BEFORE); - auto sorted = cudf::sort(data_table); + auto sorted = cudf::sort(*data_table); for (auto _ : state) { cuda_event_timer timer(state, true); - auto col = cudf::lower_bound(sorted->view(), values_table, orders, null_orders); + auto col = cudf::lower_bound(sorted->view(), *values_table, orders, null_orders); } } diff --git a/cpp/benchmarks/sort/rank.cpp b/cpp/benchmarks/sort/rank.cpp index 826740dae55..22acb241f0b 100644 --- a/cpp/benchmarks/sort/rank.cpp +++ b/cpp/benchmarks/sort/rank.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "cudf/column/column_view.hpp" #include #include @@ -32,22 +33,16 @@ class Rank : public cudf::benchmark { static void BM_rank(benchmark::State& state, bool nulls) { - using Type = int; - using column_wrapper = cudf::test::fixed_width_column_wrapper; - std::default_random_engine generator; - std::uniform_int_distribution distribution(0, 100); - + using Type = int; const cudf::size_type n_rows{(cudf::size_type)state.range(0)}; // Create columns with values in the range [0,100) - column_wrapper input = [&, n_rows]() { - auto elements = cudf::detail::make_counting_transform_iterator( - 0, [&](auto row) { return distribution(generator); }); - if (!nulls) return column_wrapper(elements, elements + n_rows); - auto valids = cudf::detail::make_counting_transform_iterator( - 0, [](auto i) { return i % 100 == 0 ? false : true; }); - return column_wrapper(elements, elements + n_rows, valids); - }(); + data_profile profile; + profile.set_null_frequency(nulls ? std::optional{0.01} : std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto keys_table = create_random_table({cudf::type_to_id()}, row_count{n_rows}, profile); + cudf::column_view input{keys_table->get_column(0)}; for (auto _ : state) { cuda_event_timer raii(state, true, rmm::cuda_stream_default); diff --git a/cpp/benchmarks/sort/sort.cpp b/cpp/benchmarks/sort/sort.cpp index e4c1af159aa..1a42daa5bb0 100644 --- a/cpp/benchmarks/sort/sort.cpp +++ b/cpp/benchmarks/sort/sort.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,31 +34,18 @@ class Sort : public cudf::benchmark { template static void BM_sort(benchmark::State& state, bool nulls) { - using Type = int; - using column_wrapper = cudf::test::fixed_width_column_wrapper; - std::default_random_engine generator; - std::uniform_int_distribution distribution(0, 100); - + using Type = int; const cudf::size_type n_rows{(cudf::size_type)state.range(0)}; const cudf::size_type n_cols{(cudf::size_type)state.range(1)}; - // Create columns with values in the range [0,100) - std::vector columns; - columns.reserve(n_cols); - std::generate_n(std::back_inserter(columns), n_cols, [&, n_rows]() { - auto elements = cudf::detail::make_counting_transform_iterator( - 0, [&](auto row) { return distribution(generator); }); - if (!nulls) return column_wrapper(elements, elements + n_rows); - auto valids = cudf::detail::make_counting_transform_iterator( - 0, [](auto i) { return i % 100 == 0 ? false : true; }); - return column_wrapper(elements, elements + n_rows, valids); - }); - - // Create column views - auto column_views = std::vector(columns.begin(), columns.end()); - - // Create table view - auto input = cudf::table_view(column_views); + // Create table with values in the range [0,100) + data_profile profile; + profile.set_null_frequency(nulls ? 
std::optional{0.01} : std::nullopt); + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); + auto input_table = create_random_table( + cycle_dtypes({cudf::type_to_id()}, n_cols), row_count{n_rows}, profile); + cudf::table_view input{*input_table}; for (auto _ : state) { cuda_event_timer raii(state, true, rmm::cuda_stream_default); diff --git a/cpp/benchmarks/stream_compaction/apply_boolean_mask.cpp b/cpp/benchmarks/stream_compaction/apply_boolean_mask.cpp index 7246d113ade..f2adb18b2b3 100644 --- a/cpp/benchmarks/stream_compaction/apply_boolean_mask.cpp +++ b/cpp/benchmarks/stream_compaction/apply_boolean_mask.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,11 @@ * limitations under the License. */ +#include #include #include #include -#include -#include - -#include -#include -#include namespace { @@ -46,16 +41,6 @@ void size_range(benchmark::internal::Benchmark* b) b->Args({size, fifty_percent}); } -template -T random_int(T min, T max) -{ - static unsigned const seed = 13377331; - static std::mt19937 engine{seed}; - static std::uniform_int_distribution uniform{min, max}; - - return uniform(engine); -} - template void calculate_bandwidth(benchmark::State& state, cudf::size_type num_columns) { @@ -88,38 +73,25 @@ void calculate_bandwidth(benchmark::State& state, cudf::size_type num_columns) template void BM_apply_boolean_mask(benchmark::State& state, cudf::size_type num_columns) { - using wrapper = cudf::test::fixed_width_column_wrapper; - using mask_wrapper = cudf::test::fixed_width_column_wrapper; - const cudf::size_type column_size{static_cast(state.range(0))}; const cudf::size_type percent_true{static_cast(state.range(1))}; - std::vector columns; - - std::vector data(column_size); - 
std::vector validity(column_size, true); - - std::iota(data.begin(), data.end(), 0); - - for (int i = 0; i < num_columns; i++) { - columns.emplace_back(data.cbegin(), data.cend(), validity.cbegin()); - } - - std::vector mask_data(column_size); - std::generate_n( - mask_data.begin(), column_size, [&]() { return random_int(0, 100) < percent_true; }); - mask_wrapper mask(mask_data.begin(), mask_data.end()); + data_profile profile; + profile.set_null_frequency(0.0); // ==0 means, all valid + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); - std::vector column_views(num_columns); + auto source_table = create_random_table( + cycle_dtypes({cudf::type_to_id()}, num_columns), row_count{column_size}, profile); - std::transform(columns.begin(), columns.end(), column_views.begin(), [](auto const& col) { - return static_cast(col); - }); - cudf::table_view source_table{column_views}; + profile.set_bool_probability(percent_true / 100.0); + profile.set_null_frequency(std::nullopt); // <0 means, no null mask + auto mask_table = create_random_table({cudf::type_id::BOOL8}, row_count{column_size}, profile); + cudf::column_view mask = mask_table->get_column(0); for (auto _ : state) { cuda_event_timer raii(state, true); - auto result = cudf::apply_boolean_mask(source_table, mask); + auto result = cudf::apply_boolean_mask(*source_table, mask); } calculate_bandwidth(state, num_columns); diff --git a/cpp/benchmarks/stream_compaction/distinct.cpp b/cpp/benchmarks/stream_compaction/distinct.cpp index 37d90894746..749badc715d 100644 --- a/cpp/benchmarks/stream_compaction/distinct.cpp +++ b/cpp/benchmarks/stream_compaction/distinct.cpp @@ -14,19 +14,15 @@ * limitations under the License. 
*/ +#include +#include + #include #include #include -#include -#include - -#include #include -#include -#include - NVBENCH_DECLARE_TYPE_STRINGS(cudf::timestamp_ms, "cudf::timestamp_ms", "cudf::timestamp_ms"); template @@ -34,16 +30,17 @@ void nvbench_distinct(nvbench::state& state, nvbench::type_list) { cudf::rmm_pool_raii pool_raii; - auto const num_rows = state.get_int64("NumRows"); + cudf::size_type const num_rows = state.get_int64("NumRows"); + + data_profile profile; + profile.set_null_frequency(0.01); + profile.set_cardinality(0); + profile.set_distribution_params(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); - cudf::test::UniformRandomGenerator rand_gen(0, 100); - auto elements = cudf::detail::make_counting_transform_iterator( - 0, [&rand_gen](auto row) { return rand_gen.generate(); }); - auto valids = - cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 100 != 0; }); - cudf::test::fixed_width_column_wrapper values(elements, elements + num_rows, valids); + auto source_table = + create_random_table(cycle_dtypes({cudf::type_to_id()}, 1), row_count{num_rows}, profile); - auto input_column = cudf::column_view(values); + auto input_column = cudf::column_view(source_table->get_column(0)); auto input_table = cudf::table_view({input_column, input_column, input_column, input_column}); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { diff --git a/cpp/benchmarks/stream_compaction/unique.cpp b/cpp/benchmarks/stream_compaction/unique.cpp index edc4097e55b..a1fc61eee5d 100644 --- a/cpp/benchmarks/stream_compaction/unique.cpp +++ b/cpp/benchmarks/stream_compaction/unique.cpp @@ -14,19 +14,15 @@ * limitations under the License. 
*/ +#include +#include + #include #include #include -#include -#include - -#include #include -#include -#include - // necessary for custom enum types // see: https://github.com/NVIDIA/nvbench/blob/main/examples/enums.cu NVBENCH_DECLARE_ENUM_TYPE_STRINGS( @@ -56,16 +52,17 @@ void nvbench_unique(nvbench::state& state, nvbench::type_list(cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); - cudf::test::UniformRandomGenerator rand_gen(0, 100); - auto elements = cudf::detail::make_counting_transform_iterator( - 0, [&rand_gen](auto row) { return rand_gen.generate(); }); - auto valids = - cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 100 != 0; }); - cudf::test::fixed_width_column_wrapper values(elements, elements + num_rows, valids); + auto source_table = + create_random_table(cycle_dtypes({cudf::type_to_id()}, 1), row_count{num_rows}, profile); - auto input_column = cudf::column_view(values); + auto input_column = cudf::column_view(source_table->get_column(0)); auto input_table = cudf::table_view({input_column, input_column, input_column, input_column}); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { diff --git a/cpp/benchmarks/string/extract.cpp b/cpp/benchmarks/string/extract.cpp index b4034ff054a..b8d206386f5 100644 --- a/cpp/benchmarks/string/extract.cpp +++ b/cpp/benchmarks/string/extract.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,6 @@ #include "string_bench_args.hpp" -#include #include #include #include @@ -52,11 +51,17 @@ static void BM_extract(benchmark::State& state, int groups) pattern += "(\\d+) "; } - std::uniform_int_distribution distribution(0, samples.size() - 1); - auto elements = cudf::detail::make_counting_transform_iterator( - 0, [&](auto idx) { return samples.at(distribution(generator)); }); - cudf::test::strings_column_wrapper input(elements, elements + n_rows); - cudf::strings_column_view view(input); + cudf::test::strings_column_wrapper samples_column(samples.begin(), samples.end()); + data_profile profile; + profile.set_null_frequency(std::nullopt); // <0 means, all valid + profile.set_distribution_params( + cudf::type_to_id(), distribution_id::UNIFORM, 0, samples.size() - 1); + auto map_table = + create_random_table({cudf::type_to_id()}, row_count{n_rows}, profile); + auto input = cudf::gather(cudf::table_view{{samples_column}}, + map_table->get_column(0).view(), + cudf::out_of_bounds_policy::DONT_CHECK); + cudf::strings_column_view view(input->get_column(0).view()); for (auto _ : state) { cuda_event_timer raii(state, true); diff --git a/cpp/benchmarks/string/json.cu b/cpp/benchmarks/string/json.cu index 69c42f97d7f..9b55375f191 100644 --- a/cpp/benchmarks/string/json.cu +++ b/cpp/benchmarks/string/json.cu @@ -14,20 +14,18 @@ * limitations under the License. 
*/ -#include #include #include #include -#include -#include -#include -#include #include +#include #include #include +#include #include +#include #include @@ -165,7 +163,7 @@ auto build_json_string_column(int desired_bytes, int num_rows) { data_profile profile; profile.set_cardinality(0); - profile.set_null_frequency(-0.1); + profile.set_null_frequency(std::nullopt); profile.set_distribution_params( cudf::type_id::FLOAT32, distribution_id::UNIFORM, 0.0, 1.0); auto float_2bool_columns = diff --git a/cpp/benchmarks/transpose/transpose.cu b/cpp/benchmarks/transpose/transpose.cpp similarity index 67% rename from cpp/benchmarks/transpose/transpose.cu rename to cpp/benchmarks/transpose/transpose.cpp index 31861c12ebe..a164b04f406 100644 --- a/cpp/benchmarks/transpose/transpose.cu +++ b/cpp/benchmarks/transpose/transpose.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,32 +16,25 @@ #include #include + +#include +#include #include -#include -#include -#include -#include -using cudf::test::fixed_width_column_wrapper; +#include +#include static void BM_transpose(benchmark::State& state) { auto count = state.range(0); - - auto data = std::vector(count, 0); - auto validity = std::vector(count, 1); - - auto fwcw_iter = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), [&data, &validity](auto idx) { - return fixed_width_column_wrapper(data.begin(), data.end(), validity.begin()); + auto int_column_generator = + thrust::make_transform_iterator(thrust::counting_iterator(0), [count](int i) { + return cudf::make_numeric_column( + cudf::data_type{cudf::type_id::INT32}, count, cudf::mask_state::ALL_VALID); }); - auto input_columns = std::vector>(fwcw_iter, fwcw_iter + count); - - auto input_column_views = - std::vector(input_columns.begin(), input_columns.end()); - - auto input = cudf::table_view(input_column_views); + auto input_table = cudf::table(std::vector(int_column_generator, int_column_generator + count)); + auto input = input_table.view(); for (auto _ : state) { cuda_event_timer raii(state, true);