diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3501bb9345c..b961080d162 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -251,7 +251,8 @@ add_library(cudf src/io/parquet/parquet.cpp src/io/parquet/reader_impl.cu src/io/parquet/writer_impl.cu - src/io/statistics/column_stats.cu + src/io/statistics/orc_column_statistics.cu + src/io/statistics/parquet_column_statistics.cu src/io/utilities/column_buffer.cpp src/io/utilities/data_sink.cpp src/io/utilities/datasource.cpp diff --git a/cpp/src/io/orc/orc_gpu.h b/cpp/src/io/orc/orc_gpu.h index 38dd69f7b9e..0f277d3d8fa 100644 --- a/cpp/src/io/orc/orc_gpu.h +++ b/cpp/src/io/orc/orc_gpu.h @@ -19,10 +19,10 @@ #include "timezone.cuh" #include -#include #include #include #include +#include #include #include "orc_common.h" diff --git a/cpp/src/io/orc/stats_enc.cu b/cpp/src/io/orc/stats_enc.cu index 56a55bd0a4d..4c85150a9f0 100644 --- a/cpp/src/io/orc/stats_enc.cu +++ b/cpp/src/io/orc/stats_enc.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -311,7 +311,7 @@ __global__ void __launch_bounds__(encode_threads_per_block) // } if (s->chunk.has_sum) { // Sum is equal to the number of 'true' values cur[0] = 5 * 8 + PB_TYPE_FIXEDLEN; - cur = pb_put_packed_uint(cur + 2, 1, s->chunk.sum.i_val); + cur = pb_put_packed_uint(cur + 2, 1, s->chunk.sum.u_val); fld_start[1] = cur - (fld_start + 2); } break; diff --git a/cpp/src/io/orc/writer_impl.cu b/cpp/src/io/orc/writer_impl.cu index b8c608c5714..2aa1e2d866a 100644 --- a/cpp/src/io/orc/writer_impl.cu +++ b/cpp/src/io/orc/writer_impl.cu @@ -21,6 +21,7 @@ #include "writer_impl.hpp" +#include #include #include @@ -851,18 +852,20 @@ std::vector> writer::impl::gather_statistic_blobs( row_index_stride_, stream); - GatherColumnStatistics(stat_chunks.data(), stat_groups.data(), num_chunks, stream); - MergeColumnStatistics(stat_chunks.data() + num_chunks, - stat_chunks.data(), - stat_merge.device_ptr(), - stripe_bounds.size() * columns.size(), - stream); - - MergeColumnStatistics(stat_chunks.data() + num_chunks + stripe_bounds.size() * columns.size(), - stat_chunks.data() + num_chunks, - stat_merge.device_ptr(stripe_bounds.size() * columns.size()), - columns.size(), - stream); + detail::calculate_group_statistics( + stat_chunks.data(), stat_groups.data(), num_chunks, stream); + detail::merge_group_statistics(stat_chunks.data() + num_chunks, + stat_chunks.data(), + stat_merge.device_ptr(), + stripe_bounds.size() * columns.size(), + stream); + + detail::merge_group_statistics( + stat_chunks.data() + num_chunks + stripe_bounds.size() * columns.size(), + stat_chunks.data() + num_chunks, + stat_merge.device_ptr(stripe_bounds.size() * columns.size()), + columns.size(), + stream); gpu::orc_init_statistics_buffersize( stat_merge.device_ptr(), stat_chunks.data() + num_chunks, num_stat_blobs, stream); stat_merge.device_to_host(stream, true); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 564226c7ff3..1b6bb9ad7ca 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -18,7 +18,7 @@ #include "io/comp/gpuinflate.h" #include "io/parquet/parquet_common.hpp" -#include "io/statistics/column_stats.h" +#include "io/statistics/statistics.cuh" #include "io/utilities/column_buffer.hpp" #include "io/utilities/hostdevice_vector.hpp" diff --git a/cpp/src/io/parquet/writer_impl.cu b/cpp/src/io/parquet/writer_impl.cu index 1a17128487a..77210b5a2ab 100644 --- a/cpp/src/io/parquet/writer_impl.cu +++ b/cpp/src/io/parquet/writer_impl.cu @@ -19,6 +19,7 @@ * @brief cuDF-IO parquet writer class implementation */ +#include #include "writer_impl.hpp" #include @@ -738,7 +739,7 @@ void writer::impl::gather_fragment_statistics( device_2dspan(frag_stats_group.data(), num_columns, num_fragments); gpu::InitFragmentStatistics(frag_stats_group_2dview, frag, col_desc, stream); - GatherColumnStatistics( + detail::calculate_group_statistics( frag_stats_chunk.data(), frag_stats_group.data(), num_fragments * num_columns, stream); stream.synchronize(); } @@ -780,13 +781,15 @@ void writer::impl::init_encoder_pages(hostdevice_2dvector & (num_stats_bfr > num_pages) ? page_stats_mrg.data() + num_pages : nullptr, stream); if (num_stats_bfr > 0) { - MergeColumnStatistics(page_stats, frag_stats, page_stats_mrg.data(), num_pages, stream); + detail::merge_group_statistics( + page_stats, frag_stats, page_stats_mrg.data(), num_pages, stream); if (num_stats_bfr > num_pages) { - MergeColumnStatistics(page_stats + num_pages, - page_stats, - page_stats_mrg.data() + num_pages, - num_stats_bfr - num_pages, - stream); + detail::merge_group_statistics( + page_stats + num_pages, + page_stats, + page_stats_mrg.data() + num_pages, + num_stats_bfr - num_pages, + stream); } } stream.synchronize(); diff --git a/cpp/src/io/statistics/column_statistics.cuh b/cpp/src/io/statistics/column_statistics.cuh new file mode 100644 index 00000000000..fd148724712 --- /dev/null +++ b/cpp/src/io/statistics/column_statistics.cuh @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file column_statistics.cuh + * @brief Functors for statistics calculation to be used in ORC and PARQUET + */ + +#pragma once + +#include "temp_storage_wrapper.cuh" + +#include "typed_statistics_chunk.cuh" + +#include "statistics.cuh" + +namespace cudf { +namespace io { + +/** + * @brief shared state for statistics calculation kernel + */ +struct stats_state_s { + stats_column_desc col; ///< Column information + statistics_group group; ///< Group description + statistics_chunk ck; ///< Output statistics chunk +}; + +/** + * @brief shared state for statistics merge kernel + */ +struct merge_state_s { + stats_column_desc col; ///< Column information + statistics_merge_group group; ///< Group description + statistics_chunk ck; ///< Resulting statistics chunk +}; + +template +using block_reduce_storage = detail::block_reduce_storage; + +/** + * @brief Functor to calculate the statistics of rows in a column belonging to a + * statistics group + * + * @tparam block_size Dimension of the block + * @tparam IO File format for which statistics calculation is being done + */ +template +struct calculate_group_statistics_functor { + block_reduce_storage &temp_storage; + + /** + * @brief Construct a statistics calculator + * + * @param d_temp_storage Temporary storage to be used by cub calls + */ + __device__ calculate_group_statistics_functor(block_reduce_storage &d_temp_storage) + : temp_storage(d_temp_storage) + { + } + + template ::is_ignored> * = nullptr> + __device__ void operator()(stats_state_s &s, uint32_t t) + { + // No-op for unsupported aggregation types + } + + /** + * @brief Iterates through the rows specified by statistics group and stores the combined + * statistics into the statistics chunk. + * + * @param s Statistics state which specifies the column, the group being worked and the chunk + * the results will be stored into + * @param t thread id + */ + template ::is_ignored> * = nullptr> + __device__ void operator()(stats_state_s &s, uint32_t t) + { + detail::storage_wrapper storage(temp_storage); + + using type_convert = detail::type_conversion>; + using CT = typename type_convert::template type; + typed_statistics_chunk::is_aggregated> chunk( + s.group.num_rows); + + for (uint32_t i = 0; i < s.group.num_rows; i += block_size) { + uint32_t r = i + t; + uint32_t row = r + s.group.start_row; + auto const is_valid = (r < s.group.num_rows) ? s.col.leaf_column->is_valid(row) : 0; + if (is_valid) { + auto converted_value = type_convert::convert(s.col.leaf_column->element(row)); + chunk.reduce(converted_value); + } + } + + chunk = block_reduce(chunk, storage); + + if (t == 0) { s.ck = get_untyped_chunk(chunk); } + } +}; + +/** + * @brief Functor to merge the statistics chunks of a column belonging to a + * merge group + * + * @tparam block_size Dimension of the block + * @tparam IO File format for which statistics calculation is being done + */ +template +struct merge_group_statistics_functor { + block_reduce_storage &temp_storage; + + __device__ merge_group_statistics_functor(block_reduce_storage &d_temp_storage) + : temp_storage(d_temp_storage) + { + } + + template ::is_ignored> * = nullptr> + __device__ void operator()(merge_state_s &s, + const statistics_chunk *chunks, + const uint32_t num_chunks, + uint32_t t) + { + // No-op for unsupported aggregation types + } + + template ::is_ignored> * = nullptr> + __device__ void operator()(merge_state_s &s, + const statistics_chunk *chunks, + const uint32_t num_chunks, + uint32_t t) + { + detail::storage_wrapper storage(temp_storage); + + typed_statistics_chunk::is_aggregated> chunk; + + for (uint32_t i = t; i < num_chunks; i += block_size) { chunk.reduce(chunks[i]); } + chunk.has_minmax = (chunk.minimum_value <= chunk.maximum_value); + + chunk = block_reduce(chunk, storage); + + if (t == 0) { s.ck = get_untyped_chunk(chunk); } + } +}; + +/** + * @brief Function to cooperatively load an object from a pointer + * + * If the pointer is nullptr then the members of the object are set to 0 + * + * @param[out] destination Object being loaded + * @param[in] source Source object + * @tparam T Type of object + */ +template +__device__ void cooperative_load(T &destination, const T *source = nullptr) +{ + using load_type = std::conditional_t<((sizeof(T) % sizeof(uint32_t)) == 0), uint32_t, uint8_t>; + if (source == nullptr) { + for (auto i = threadIdx.x; i < (sizeof(T) / sizeof(load_type)); i += blockDim.x) { + reinterpret_cast(&destination)[i] = load_type{0}; + } + } else { + for (auto i = threadIdx.x; i < sizeof(T) / sizeof(load_type); i += blockDim.x) { + reinterpret_cast(&destination)[i] = + reinterpret_cast(source)[i]; + } + } +} + +/** + * @brief Kernel to calculate group statistics + * + * @param[out] chunks Statistics results [num_chunks] + * @param[in] groups Statistics row groups [num_chunks] + * @tparam block_size Dimension of the block + * @tparam IO File format for which statistics calculation is being done + */ +template +__global__ void __launch_bounds__(block_size, 1) + gpu_calculate_group_statistics(statistics_chunk *chunks, const statistics_group *groups) +{ + __shared__ __align__(8) stats_state_s state; + __shared__ block_reduce_storage storage; + + // Load state members + cooperative_load(state.group, &groups[blockIdx.x]); + cooperative_load(state.ck); + __syncthreads(); + cooperative_load(state.col, state.group.col); + __syncthreads(); + + // Calculate statistics + type_dispatcher(state.col.leaf_column->type(), + calculate_group_statistics_functor(storage), + state, + threadIdx.x); + __syncthreads(); + + cooperative_load(chunks[blockIdx.x], &state.ck); +} + +namespace detail { + +/** + * @brief Launches kernel to calculate group statistics + * + * @param[out] chunks Statistics results [num_chunks] + * @param[in] groups Statistics row groups [num_chunks] + * @param[in] num_chunks Number of chunks & rowgroups + * @param[in] stream CUDA stream to use + * @tparam IO File format for which statistics calculation is being done + */ +template +void calculate_group_statistics(statistics_chunk *chunks, + const statistics_group *groups, + uint32_t num_chunks, + rmm::cuda_stream_view stream) +{ + constexpr int block_size = 256; + gpu_calculate_group_statistics + <<>>(chunks, groups); +} + +/** + * @brief Kernel to merge column statistics + * + * @param[out] chunks_out Statistics results [num_chunks] + * @param[in] chunks_in Input statistics + * @param[in] groups Statistics groups [num_chunks] + * @tparam block_size Dimension of the block + * @tparam IO File format for which statistics calculation is being done + */ +template +__global__ void __launch_bounds__(block_size, 1) + gpu_merge_group_statistics(statistics_chunk *chunks_out, + const statistics_chunk *chunks_in, + const statistics_merge_group *groups) +{ + __shared__ __align__(8) merge_state_s state; + __shared__ block_reduce_storage storage; + + cooperative_load(state.group, &groups[blockIdx.x]); + __syncthreads(); + cooperative_load(state.col, state.group.col); + __syncthreads(); + + type_dispatcher(state.col.leaf_column->type(), + merge_group_statistics_functor(storage), + state, + chunks_in + state.group.start_chunk, + state.group.num_chunks, + threadIdx.x); + __syncthreads(); + + cooperative_load(chunks_out[blockIdx.x], &state.ck); +} + +/** + * @brief Launches kernel to merge column statistics + * + * @param[out] chunks_out Statistics results [num_chunks] + * @param[in] chunks_in Input statistics + * @param[in] groups Statistics groups [num_chunks] + * @param[in] num_chunks Number of chunks & groups + * @param[in] stream CUDA stream to use + * @tparam IO File format for which statistics calculation is being done + */ +template +void merge_group_statistics(statistics_chunk *chunks_out, + const statistics_chunk *chunks_in, + const statistics_merge_group *groups, + uint32_t num_chunks, + rmm::cuda_stream_view stream) +{ + constexpr int block_size = 256; + gpu_merge_group_statistics + <<>>(chunks_out, chunks_in, groups); +} + +} // namespace detail +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/statistics/column_stats.cu b/cpp/src/io/statistics/column_stats.cu deleted file mode 100644 index 52f21f0a9ad..00000000000 --- a/cpp/src/io/statistics/column_stats.cu +++ /dev/null @@ -1,566 +0,0 @@ -/* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "column_stats.h" - -#include - -#include - -#include - -#include - -constexpr int block_size = 1024; - -namespace cudf { -namespace io { -/** - * @brief shared state for statistics gather kernel - */ -struct stats_state_s { - stats_column_desc col; ///< Column information - statistics_group group; ///< Group description - statistics_chunk ck; ///< Output statistics chunk - volatile statistics_val warp_min[32]; ///< Min reduction scratch - volatile statistics_val warp_max[32]; ///< Max reduction scratch -}; - -/** - * @brief shared state for statistics merge kernel - */ -struct merge_state_s { - stats_column_desc col; ///< Column information - statistics_merge_group group; ///< Group description - statistics_chunk ck; ///< Resulting statistics chunk - volatile statistics_val warp_min[32]; ///< Min reduction scratch - volatile statistics_val warp_max[32]; ///< Max reduction scratch -}; - -/** - * Custom addition functor to ignore NaN inputs - */ -struct IgnoreNaNSum { - __device__ __forceinline__ double operator()(const double &a, const double &b) - { - double aval = isnan(a) ? 0 : a; - double bval = isnan(b) ? 0 : b; - return aval + bval; - } -}; - -/** - * @brief Gather statistics for integer-like columns - * - * @param s shared block state - * @param dtype data type - * @param t thread id - * @param storage temporary storage for reduction - */ -template -void __device__ -gatherIntColumnStats(stats_state_s *s, statistics_dtype dtype, uint32_t t, Storage &storage) -{ - using block_reduce = cub::BlockReduce; - int64_t vmin = INT64_MAX; - int64_t vmax = INT64_MIN; - int64_t vsum = 0; - int64_t v; - uint32_t nn_cnt = 0; - __shared__ volatile bool has_minmax; - for (uint32_t i = 0; i < s->group.num_rows; i += block_size) { - uint32_t r = i + t; - uint32_t row = r + s->group.start_row; - uint32_t is_valid = (r < s->group.num_rows) ? s->col.leaf_column->is_valid(row) : 0; - if (is_valid) { - switch (dtype) { - case dtype_int32: - case dtype_date32: v = s->col.leaf_column->element(row); break; - case dtype_int64: - case dtype_decimal64: v = s->col.leaf_column->element(row); break; - case dtype_int16: v = s->col.leaf_column->element(row); break; - case dtype_timestamp64: - v = s->col.leaf_column->element(row); - if (s->col.ts_scale < -1) { - v /= -s->col.ts_scale; - } else if (s->col.ts_scale > 1) { - v *= s->col.ts_scale; - } - break; - default: v = s->col.leaf_column->element(row); break; - } - vmin = min(vmin, v); - vmax = max(vmax, v); - vsum += v; - } - nn_cnt += __syncthreads_count(is_valid); - } - if (!t) { - s->ck.non_nulls = nn_cnt; - s->ck.null_count = s->group.num_rows - nn_cnt; - } - vmin = block_reduce(storage.integer_stats).Reduce(vmin, cub::Min()); - __syncthreads(); - vmax = block_reduce(storage.integer_stats).Reduce(vmax, cub::Max()); - if (!t) { has_minmax = (vmin <= vmax); } - __syncthreads(); - if (has_minmax) { vsum = block_reduce(storage.integer_stats).Sum(vsum); } - if (!t) { - if (has_minmax) { - s->ck.min_value.i_val = vmin; - s->ck.max_value.i_val = vmax; - s->ck.sum.i_val = vsum; - } - s->ck.has_minmax = has_minmax; - // TODO: For now, don't set the sum flag with 64-bit values so we don't have to check for - // 64-bit sum overflow - s->ck.has_sum = (dtype <= dtype_int32 && has_minmax); - } -} - -/** - * @brief Gather statistics for floating-point columns - * - * @param s shared block state - * @param dtype data type - * @param t thread id - * @param storage temporary storage for reduction - */ -template -void __device__ -gatherFloatColumnStats(stats_state_s *s, statistics_dtype dtype, uint32_t t, Storage &storage) -{ - using block_reduce = cub::BlockReduce; - double vmin = CUDART_INF; - double vmax = -CUDART_INF; - double vsum = 0; - double v; - uint32_t nn_cnt = 0; - __shared__ volatile bool has_minmax; - for (uint32_t i = 0; i < s->group.num_rows; i += block_size) { - uint32_t r = i + t; - uint32_t row = r + s->group.start_row; - uint32_t is_valid = (r < s->group.num_rows) ? s->col.leaf_column->is_valid(row) : 0; - if (is_valid) { - if (dtype == dtype_float64) { - v = s->col.leaf_column->element(row); - } else { - v = s->col.leaf_column->element(row); - } - vmin = min(vmin, v); - vmax = max(vmax, v); - if (!isnan(v)) { vsum += v; } - } - nn_cnt += __syncthreads_count(is_valid); - } - if (!t) { - s->ck.non_nulls = nn_cnt; - s->ck.null_count = s->group.num_rows - nn_cnt; - } - vmin = block_reduce(storage.float_stats).Reduce(vmin, cub::Min()); - __syncthreads(); - vmax = block_reduce(storage.float_stats).Reduce(vmax, cub::Max()); - if (!t) { has_minmax = (vmin <= vmax); } - __syncthreads(); - if (has_minmax) { vsum = block_reduce(storage.float_stats).Reduce(vsum, IgnoreNaNSum()); } - if (!t) { - if (has_minmax) { - s->ck.min_value.fp_val = (vmin != 0.0) ? vmin : CUDART_NEG_ZERO; - s->ck.max_value.fp_val = (vmax != 0.0) ? vmax : CUDART_ZERO; - s->ck.sum.fp_val = vsum; - } - s->ck.has_minmax = has_minmax; - s->ck.has_sum = has_minmax; // Implies sum is valid as well - } -} - -/** - * @brief Gather statistics for string columns - * - * @param s shared block state - * @param t thread id - * @param storage temporary storage for reduction - */ -template -void __device__ gatherStringColumnStats(stats_state_s *s, uint32_t t, Storage &storage) -{ - using block_reduce = cub::BlockReduce; - using string_reduce = cub::BlockReduce; - uint32_t len_sum = 0; - uint32_t nn_cnt = 0; - bool has_minmax = false; - - string_view minimum_value = string_view::max(); - string_view maximum_value = string_view::min(); - - for (uint32_t i = 0; i < s->group.num_rows; i += block_size) { - uint32_t r = i + t; - uint32_t row = r + s->group.start_row; - uint32_t is_valid = (r < s->group.num_rows) ? s->col.leaf_column->is_valid(row) : 0; - if (is_valid) { - has_minmax = true; - auto str = s->col.leaf_column->element(row); - len_sum += str.size_bytes(); - minimum_value = thrust::min(minimum_value, str); - maximum_value = thrust::max(maximum_value, str); - } - nn_cnt += __syncthreads_count(is_valid); - } - if (!t) { - s->ck.non_nulls = nn_cnt; - s->ck.null_count = s->group.num_rows - nn_cnt; - } - minimum_value = string_reduce(storage.string_val_stats).Reduce(minimum_value, cub::Min()); - __syncthreads(); - maximum_value = string_reduce(storage.string_val_stats).Reduce(maximum_value, cub::Max()); - has_minmax = __syncthreads_or(has_minmax); - if (has_minmax) { len_sum = block_reduce(storage.string_stats).Sum(len_sum); } - - if (!t) { - if (has_minmax) { - s->ck.min_value.str_val = minimum_value; - s->ck.max_value.str_val = maximum_value; - s->ck.sum.i_val = len_sum; - } - s->ck.has_minmax = has_minmax; - s->ck.has_sum = has_minmax; - } -} - -/** - * @brief Gather column chunk statistics (min/max values, sum and null count) - * for a group of rows. - * - * blockDim {1024,1,1} - * - * @param chunks Destination statistics results - * @param groups Statistics source information - */ -template -__global__ void __launch_bounds__(block_size, 1) - gpuGatherColumnStatistics(statistics_chunk *chunks, const statistics_group *groups) -{ - __shared__ __align__(8) stats_state_s state_g; - __shared__ union { - typename cub::BlockReduce::TempStorage integer_stats; - typename cub::BlockReduce::TempStorage float_stats; - typename cub::BlockReduce::TempStorage string_stats; - typename cub::BlockReduce::TempStorage string_val_stats; - } temp_storage; - - stats_state_s *const s = &state_g; - uint32_t t = threadIdx.x; - statistics_dtype dtype; - - if (t < sizeof(statistics_group) / sizeof(uint32_t)) { - reinterpret_cast(&s->group)[t] = - reinterpret_cast(&groups[blockIdx.x])[t]; - } - if (t < sizeof(statistics_chunk) / sizeof(uint32_t)) { - reinterpret_cast(&s->ck)[t] = 0; - } - __syncthreads(); - if (t < sizeof(stats_column_desc) / sizeof(uint32_t)) { - reinterpret_cast(&s->col)[t] = reinterpret_cast(s->group.col)[t]; - } - __syncthreads(); - dtype = s->col.stats_dtype; - if (dtype >= dtype_bool && dtype <= dtype_decimal64) { - gatherIntColumnStats(s, dtype, t, temp_storage); - } else if (dtype >= dtype_float32 && dtype <= dtype_float64) { - gatherFloatColumnStats(s, dtype, t, temp_storage); - } else if (dtype == dtype_string) { - gatherStringColumnStats(s, t, temp_storage); - } - __syncthreads(); - if (t < sizeof(statistics_chunk) / sizeof(uint32_t)) { - reinterpret_cast(&chunks[blockIdx.x])[t] = reinterpret_cast(&s->ck)[t]; - } -} - -/** - * @brief Merge statistics for integer-like columns - * - * @param s shared block state - * @param dtype data type - * @param ck_in pointer to first statistic chunk - * @param num_chunks number of statistic chunks to merge - * @param t thread id - * @param storage temporary storage for reduction - */ -template -void __device__ mergeIntColumnStats(merge_state_s *s, - statistics_dtype dtype, - const statistics_chunk *ck_in, - uint32_t num_chunks, - uint32_t t, - Storage &storage) -{ - int64_t vmin = INT64_MAX; - int64_t vmax = INT64_MIN; - int64_t vsum = 0; - uint32_t non_nulls = 0; - uint32_t null_count = 0; - __shared__ volatile bool has_minmax; - for (uint32_t i = t; i < num_chunks; i += block_size) { - const statistics_chunk *ck = &ck_in[i]; - if (ck->has_minmax) { - vmin = min(vmin, ck->min_value.i_val); - vmax = max(vmax, ck->max_value.i_val); - } - if (ck->has_sum) { vsum += ck->sum.i_val; } - non_nulls += ck->non_nulls; - null_count += ck->null_count; - } - vmin = cub::BlockReduce(storage.i64).Reduce(vmin, cub::Min()); - __syncthreads(); - vmax = cub::BlockReduce(storage.i64).Reduce(vmax, cub::Max()); - if (!t) { has_minmax = (vmin <= vmax); } - __syncthreads(); - non_nulls = cub::BlockReduce(storage.u32).Sum(non_nulls); - __syncthreads(); - null_count = cub::BlockReduce(storage.u32).Sum(null_count); - __syncthreads(); - if (has_minmax) { vsum = cub::BlockReduce(storage.i64).Sum(vsum); } - - if (!t) { - if (has_minmax) { - s->ck.min_value.i_val = vmin; - s->ck.max_value.i_val = vmax; - s->ck.sum.i_val = vsum; - } - s->ck.has_minmax = has_minmax; - // TODO: For now, don't set the sum flag with 64-bit values so we don't have to check for - // 64-bit sum overflow - s->ck.has_sum = (dtype <= dtype_int32 && has_minmax); - s->ck.non_nulls = non_nulls; - s->ck.null_count = null_count; - } -} - -/** - * @brief Merge statistics for floating-point columns - * - * @param s shared block state - * @param dtype data type - * @param ck_in pointer to first statistic chunk - * @param num_chunks number of statistic chunks to merge - * @param t thread id - * @param storage temporary storage for reduction - */ -template -void __device__ mergeFloatColumnStats(merge_state_s *s, - const statistics_chunk *ck_in, - uint32_t num_chunks, - uint32_t t, - Storage &storage) -{ - double vmin = CUDART_INF; - double vmax = -CUDART_INF; - double vsum = 0; - uint32_t non_nulls = 0; - uint32_t null_count = 0; - __shared__ volatile bool has_minmax; - for (uint32_t i = t; i < num_chunks; i += block_size) { - const statistics_chunk *ck = &ck_in[i]; - if (ck->has_minmax) { - vmin = min(vmin, ck->min_value.fp_val); - vmax = max(vmax, ck->max_value.fp_val); - } - if (ck->has_sum) { vsum += ck->sum.fp_val; } - non_nulls += ck->non_nulls; - null_count += ck->null_count; - } - - vmin = cub::BlockReduce(storage.f64).Reduce(vmin, cub::Min()); - __syncthreads(); - vmax = cub::BlockReduce(storage.f64).Reduce(vmax, cub::Max()); - if (!t) { has_minmax = (vmin <= vmax); } - __syncthreads(); - non_nulls = cub::BlockReduce(storage.u32).Sum(non_nulls); - __syncthreads(); - null_count = cub::BlockReduce(storage.u32).Sum(null_count); - __syncthreads(); - if (has_minmax) { - vsum = cub::BlockReduce(storage.f64).Reduce(vsum, IgnoreNaNSum()); - } - - if (!t) { - if (has_minmax) { - s->ck.min_value.fp_val = (vmin != 0.0) ? vmin : CUDART_NEG_ZERO; - s->ck.max_value.fp_val = (vmax != 0.0) ? vmax : CUDART_ZERO; - s->ck.sum.fp_val = vsum; - } - s->ck.has_minmax = has_minmax; - s->ck.has_sum = has_minmax; // Implies sum is valid as well - s->ck.non_nulls = non_nulls; - s->ck.null_count = null_count; - } -} - -/** - * @brief Merge statistics for string columns - * - * @param s shared block state - * @param ck_in pointer to first statistic chunk - * @param num_chunks number of statistic chunks to merge - * @param t thread id - * @param storage temporary storage for reduction - */ -template -void __device__ mergeStringColumnStats(merge_state_s *s, - const statistics_chunk *ck_in, - uint32_t num_chunks, - uint32_t t, - Storage &storage) -{ - using block_reduce = cub::BlockReduce; - using string_reduce = cub::BlockReduce; - uint32_t len_sum = 0; - uint32_t non_nulls = 0; - uint32_t null_count = 0; - bool has_minmax = false; - - string_view minimum_value = string_view::max(); - string_view maximum_value = string_view::min(); - - for (uint32_t i = t; i < num_chunks; i += block_size) { - const statistics_chunk *ck = &ck_in[i]; - if (ck->has_minmax) { - has_minmax = true; - minimum_value = thrust::min(minimum_value, ck->min_value.str_val); - maximum_value = thrust::max(maximum_value, ck->max_value.str_val); - } - if (ck->has_sum) { len_sum += (uint32_t)ck->sum.i_val; } - non_nulls += ck->non_nulls; - null_count += ck->null_count; - } - minimum_value = string_reduce(storage.str).Reduce(minimum_value, cub::Min()); - __syncthreads(); - maximum_value = string_reduce(storage.str).Reduce(maximum_value, cub::Max()); - has_minmax = __syncthreads_or(has_minmax); - - non_nulls = block_reduce(storage.u32).Sum(non_nulls); - __syncthreads(); - null_count = block_reduce(storage.u32).Sum(null_count); - __syncthreads(); - if (has_minmax) { len_sum = block_reduce(storage.u32).Sum(len_sum); } - - if (!t) { - if (has_minmax) { - s->ck.min_value.str_val = minimum_value; - s->ck.max_value.str_val = maximum_value; - s->ck.sum.i_val = len_sum; - } - s->ck.has_minmax = has_minmax; - s->ck.has_sum = has_minmax; - s->ck.non_nulls = non_nulls; - s->ck.null_count = null_count; - } -} - -/** - * @brief Combine multiple statistics chunk together to form new statistics chunks - * - * blockDim {1024,1,1} - * - * @param chunks_out Destination statistic chunks - * @param chunks_in Source statistic chunks - * @param groups Statistic chunk grouping information - */ -template -__global__ void __launch_bounds__(block_size, 1) - gpuMergeColumnStatistics(statistics_chunk *chunks_out, - const statistics_chunk *chunks_in, - const statistics_merge_group *groups) -{ - __shared__ __align__(8) merge_state_s state_g; - __shared__ struct { - typename cub::BlockReduce::TempStorage u32; - typename cub::BlockReduce::TempStorage i64; - typename cub::BlockReduce::TempStorage f64; - typename cub::BlockReduce::TempStorage str; - } storage; - - merge_state_s *const s = &state_g; - uint32_t t = threadIdx.x; - statistics_dtype dtype; - - if (t < sizeof(statistics_merge_group) / sizeof(uint32_t)) { - reinterpret_cast(&s->group)[t] = - reinterpret_cast(&groups[blockIdx.x])[t]; - } - __syncthreads(); - if (t < sizeof(stats_column_desc) / sizeof(uint32_t)) { - reinterpret_cast(&s->col)[t] = reinterpret_cast(s->group.col)[t]; - } - __syncthreads(); - dtype = s->col.stats_dtype; - - if (dtype >= dtype_bool && dtype <= dtype_decimal64) { - mergeIntColumnStats( - s, dtype, chunks_in + s->group.start_chunk, s->group.num_chunks, t, storage); - } else if (dtype >= dtype_float32 && dtype <= dtype_float64) { - mergeFloatColumnStats(s, chunks_in + s->group.start_chunk, s->group.num_chunks, t, storage); - } else if (dtype == dtype_string) { - mergeStringColumnStats(s, chunks_in + s->group.start_chunk, s->group.num_chunks, t, storage); - } - - __syncthreads(); - if (t < sizeof(statistics_chunk) / sizeof(uint32_t)) { - reinterpret_cast(&chunks_out[blockIdx.x])[t] = - reinterpret_cast(&s->ck)[t]; - } -} - -/** - * @brief Launches kernel to gather column statistics - * - * @param[out] chunks Statistics results [num_chunks] - * @param[in] groups Statistics row groups [num_chunks] - * @param[in] num_chunks Number of chunks & rowgroups - * @param[in] stream CUDA stream to use, default 0 - */ -void GatherColumnStatistics(statistics_chunk *chunks, - const statistics_group *groups, - uint32_t num_chunks, - rmm::cuda_stream_view stream) -{ - gpuGatherColumnStatistics - <<>>(chunks, groups); -} - -/** - * @brief Launches kernel to merge column statistics - * - * @param[out] chunks_out Statistics results [num_chunks] - * @param[out] chunks_in Input statistics - * @param[in] groups Statistics groups [num_chunks] - * @param[in] num_chunks Number of chunks & groups - * @param[in] stream CUDA stream to use, default 0 - */ -void MergeColumnStatistics(statistics_chunk *chunks_out, - const statistics_chunk *chunks_in, - const statistics_merge_group *groups, - uint32_t num_chunks, - rmm::cuda_stream_view stream) -{ - gpuMergeColumnStatistics - <<>>(chunks_out, chunks_in, groups); -} - -} // namespace io -} // namespace cudf diff --git a/cpp/src/io/statistics/conversion_type_select.cuh b/cpp/src/io/statistics/conversion_type_select.cuh new file mode 100644 index 00000000000..225377bfc4b --- /dev/null +++ b/cpp/src/io/statistics/conversion_type_select.cuh @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file conversion_type_select.cuh + * @brief Utility classes for timestamp and duration conversion for PARQUET and ORC + */ + +#pragma once + +#include +#include +#include + +namespace cudf { +namespace io { +namespace detail { + +template +class DetectInnerIteration; + +template +class DetectInnerIteration> { + public: + static constexpr bool is_duplicate = + std::is_same_v>::type, + typename std::tuple_element<0, std::tuple>::type>; +}; + +template +class DetectInnerIteration> { + public: + static constexpr bool is_duplicate = + std::is_same_v>::type, + typename std::tuple_element>::type> || + DetectInnerIteration>::is_duplicate; +}; + +template +class DetectIteration; + +template +class DetectIteration<0, std::tuple> { + public: + static constexpr bool is_duplicate = false; +}; + +template +class DetectIteration> { + public: + static constexpr bool is_duplicate = + DetectInnerIteration>::is_duplicate || + DetectIteration>::is_duplicate; +}; + +template +class Detect; + +/** + * @brief Utility class to detect multiple occurences of a type in the first element of pairs in a + * tuple For eg. with the following tuple : + * + * using conversion_types = + * std::tuple< + * std::pair, + * std::pair, + * std::pair, + * std::pair, + * std::pair, + * std::pair>; + * + * Detect::is_duplicate will evaluate to true at compile time. + * Here std::pair, std::pair and std::pair are treated as duplicates + * and std::pair and std::pair> are treated as duplicates. + * + * @tparam T... Parameter pack of pairs of types + */ +template +class Detect> { + public: + static constexpr bool is_duplicate = + DetectIteration<(sizeof...(T) - 1), std::tuple>::is_duplicate; +}; + +template +class ConversionTypeSelect; + +template +class ConversionTypeSelect> { + public: + template + using type = std::conditional_t::type>, + typename std::tuple_element<1, I0>::type, + T>; +}; + +/** + * @brief Utility to select between types based on an input type + * + * using Conversion = std::tuple< + * std::pair, + * std::pair, + * std::pair, + * std::pair> + * + * using type = ConversionTypeSelect::type + * Here type will resolve to cudf::duration_us + * If the type passed does not match any entries the type is returned as it is + * This utility takes advantage of Detect class to reject any tuple with duplicate first + * entries at compile time + * + * @tparam T... Parameter pack of pairs of types + */ +template +class ConversionTypeSelect> { + public: + template + using type = + std::conditional_t::type>, + typename std::tuple_element<1, I0>::type, + typename ConversionTypeSelect>::template type>; + + static_assert(not Detect>::is_duplicate, + "Type tuple has duplicate first entries"); +}; + +} // namespace detail +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/statistics/orc_column_statistics.cu b/cpp/src/io/statistics/orc_column_statistics.cu new file mode 100644 index 00000000000..ad8a05a56f5 --- /dev/null +++ b/cpp/src/io/statistics/orc_column_statistics.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file orc_column_statistics.cu + * @brief Template specialization for ORC statistics calls + */ + +#include "column_statistics.cuh" + +namespace cudf { +namespace io { +namespace detail { + +template <> +void merge_group_statistics(statistics_chunk *chunks_out, + const statistics_chunk *chunks_in, + const statistics_merge_group *groups, + uint32_t num_chunks, + rmm::cuda_stream_view stream); +template <> +void calculate_group_statistics(statistics_chunk *chunks, + const statistics_group *groups, + uint32_t num_chunks, + rmm::cuda_stream_view stream); + +} // namespace detail +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/statistics/parquet_column_statistics.cu b/cpp/src/io/statistics/parquet_column_statistics.cu new file mode 100644 index 00000000000..ad067cd4aad --- /dev/null +++ b/cpp/src/io/statistics/parquet_column_statistics.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file parquet_column_statistics.cu + * @brief Template specialization for PARQUET statistics calls + */ + +#include "column_statistics.cuh" + +namespace cudf { +namespace io { +namespace detail { + +template <> +void merge_group_statistics(statistics_chunk *chunks_out, + const statistics_chunk *chunks_in, + const statistics_merge_group *groups, + uint32_t num_chunks, + rmm::cuda_stream_view stream); +template <> +void calculate_group_statistics(statistics_chunk *chunks, + const statistics_group *groups, + uint32_t num_chunks, + rmm::cuda_stream_view stream); + +} // namespace detail +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/statistics/column_stats.h b/cpp/src/io/statistics/statistics.cuh similarity index 68% rename from cpp/src/io/statistics/column_stats.h rename to cpp/src/io/statistics/statistics.cuh index d7895de50ce..f7bf6e407c1 100644 --- a/cpp/src/io/statistics/column_stats.h +++ b/cpp/src/io/statistics/statistics.cuh @@ -13,6 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +/** + * @file statistics.cuh + * @brief Common structures and utility functions for statistics + */ + #pragma once #include @@ -69,16 +75,17 @@ struct string_stats { { return string_view(ptr, static_cast(length)); } + __host__ __device__ __forceinline__ operator string_view() + { + return string_view(ptr, static_cast(length)); + } }; union statistics_val { string_stats str_val; //!< string columns double fp_val; //!< float columns int64_t i_val; //!< integer columns - struct { - uint64_t lo64; - int64_t hi64; - } i128_val; //!< decimal128 columns + uint64_t u_val; //!< unsigned integer columns }; struct statistics_chunk { @@ -86,12 +93,9 @@ struct statistics_chunk { uint32_t null_count; //!< number of null values in chunk statistics_val min_value; //!< minimum value in chunk statistics_val max_value; //!< maximum value in chunk - union { - double fp_val; //!< Sum for fp types - int64_t i_val; //!< Sum for integer types or string lengths - } sum; - uint8_t has_minmax; //!< Nonzero if min_value and max_values are valid - uint8_t has_sum; //!< Nonzero if sum is valid + statistics_val sum; //!< sum of chunk + uint8_t has_minmax; //!< Nonzero if min_value and max_values are valid + uint8_t has_sum; //!< Nonzero if sum is valid }; struct statistics_group { @@ -106,33 +110,5 @@ struct statistics_merge_group { uint32_t num_chunks; //!< Number of chunks in group }; -/** - * @brief Launches kernel to gather column statistics - * - * @param[out] chunks Statistics results [num_chunks] - * @param[in] groups Statistics row groups [num_chunks] - * @param[in] num_chunks Number of chunks & rowgroups - * @param[in] stream CUDA stream to use, default 0 - */ -void GatherColumnStatistics(statistics_chunk *chunks, - const statistics_group *groups, - uint32_t num_chunks, - rmm::cuda_stream_view stream); - -/** - * @brief Launches kernel to merge column statistics - * - * @param[out] chunks_out Statistics results [num_chunks] - * @param[out] chunks_in Input statistics - * @param[in] groups Statistics groups [num_chunks] - * @param[in] num_chunks Number of chunks & groups - * @param[in] stream CUDA stream to use, default 0 - */ -void MergeColumnStatistics(statistics_chunk *chunks_out, - const statistics_chunk *chunks_in, - const statistics_merge_group *groups, - uint32_t num_chunks, - rmm::cuda_stream_view stream); - } // namespace io } // namespace cudf diff --git a/cpp/src/io/statistics/statistics_type_identification.cuh b/cpp/src/io/statistics/statistics_type_identification.cuh new file mode 100644 index 00000000000..84399a307a5 --- /dev/null +++ b/cpp/src/io/statistics/statistics_type_identification.cuh @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file statistics_type_identification.cuh + * @brief Utility classes to identify extrema, aggregate and conversion types for ORC and PARQUET + */ + +#pragma once + +#include + +#include + +#include + +#include + +#include + +#include "conversion_type_select.cuh" + +#include + +namespace cudf { +namespace io { +namespace detail { + +enum class io_file_format { ORC, PARQUET }; + +template +struct conversion_map; + +// Every timestamp or duration type is converted to milliseconds in ORC statistics +template <> +struct conversion_map { + using types = std::tuple, + std::pair, + std::pair, + std::pair, + std::pair, + std::pair>; +}; + +// In Parquet timestamps and durations with second resoluion are converted to +// milliseconds. Timestamps and durations with nanosecond resoluion are +// converted to microseconds. +template <> +struct conversion_map { + using types = std::tuple, + std::pair, + std::pair, + std::pair>; +}; + +/** + * @brief Utility class to help conversion of timestamps and durations to their + * representation type + * + * @tparam conversion A conversion_map structure + */ +template +class type_conversion { + using type_selector = ConversionTypeSelect; + + public: + template + using type = typename type_selector::template type; + + template + static constexpr __device__ typename type_selector::template type convert(const T& elem) + { + using Type = typename type_selector::template type; + if constexpr (cudf::is_duration()) { + return cuda::std::chrono::duration_cast(elem); + } else if constexpr (cudf::is_timestamp()) { + using Duration = typename Type::duration; + return cuda::std::chrono::time_point_cast(elem); + } else { + return elem; + } + return Type{}; + } +}; + +template +struct dependent_false : std::false_type { +}; + +/** + * @brief Utility class to convert a leaf column element into its extrema type + * + * @tparam T Column type + */ +template +class extrema_type { + private: + using integral_extrema_type = typename std::conditional_t, int64_t, uint64_t>; + + using arithmetic_extrema_type = + typename std::conditional_t, integral_extrema_type, double>; + + using non_arithmetic_extrema_type = typename std::conditional_t< + cudf::is_fixed_point() or cudf::is_duration() or cudf::is_timestamp(), + int64_t, + typename std::conditional_t, string_view, void>>; + + // unsigned int/bool -> uint64_t + // signed int -> int64_t + // float/double -> double + // decimal32/64 -> int64_t + // duration_[T] -> int64_t + // string_view -> string_view + // timestamp_[T] -> int64_t + + public: + // Does type T have an extrema? + static constexpr bool is_supported = std::is_arithmetic_v or std::is_same_v or + cudf::is_duration() or cudf::is_timestamp() or + cudf::is_fixed_point(); + + using type = typename std:: + conditional_t, arithmetic_extrema_type, non_arithmetic_extrema_type>; + + /** + * @brief Function that converts an element of a leaf column into its extrema type + */ + __device__ static type convert(const T& val) + { + if constexpr (std::is_arithmetic_v or std::is_same_v) { + return val; + } else if constexpr (cudf::is_fixed_point()) { + return val.value(); + } else if constexpr (cudf::is_duration()) { + return val.count(); + } else if constexpr (cudf::is_timestamp()) { + return val.time_since_epoch().count(); + } else { + static_assert(dependent_false::value, "aggregation_type does not exist"); + } + return type{}; + } +}; + +/** + * @brief Utility class to convert a leaf column element into its aggregate type + * + * @tparam T Column type + */ +template +class aggregation_type { + private: + using integral_aggregation_type = + typename std::conditional_t, int64_t, uint64_t>; + + using arithmetic_aggregation_type = + typename std::conditional_t, integral_aggregation_type, double>; + + using non_arithmetic_aggregation_type = + typename std::conditional_t() or cudf::is_duration() or + cudf::is_timestamp() // To be disabled with static_assert + or std::is_same_v, + int64_t, + void>; + + // unsigned int/bool -> uint64_t + // signed int -> int64_t + // float/double -> double + // decimal32/64 -> int64_t + // duration_[T] -> int64_t + // string_view -> int64_t + // NOTE : timestamps do not have an aggregation type + + public: + // Does type T aggregate? + static constexpr bool is_supported = std::is_arithmetic_v or std::is_same_v or + cudf::is_duration() or cudf::is_fixed_point(); + + using type = typename std::conditional_t, + arithmetic_aggregation_type, + non_arithmetic_aggregation_type>; + + /** + * @brief Function that converts an element of a leaf column into its aggregate type + */ + __device__ static type convert(const T& val) + { + if constexpr (std::is_same_v) { + return val.size_bytes(); + } else if constexpr (std::is_integral_v) { + return val; + } else if constexpr (std::is_floating_point_v) { + return isnan(val) ? 0 : val; + } else if constexpr (cudf::is_fixed_point()) { + return val.value(); + } else if constexpr (cudf::is_duration()) { + return val.count(); + } else if constexpr (cudf::is_timestamp()) { + static_assert(dependent_false::value, "aggregation_type for timestamps do not exist"); + } else { + static_assert(dependent_false::value, "aggregation_type for supplied type do not exist"); + } + return type{}; + } +}; + +template +__inline__ __device__ constexpr T minimum_identity() +{ + if constexpr (std::is_same_v) { return string_view::max(); } + return cuda::std::numeric_limits::max(); +} + +template +__inline__ __device__ constexpr T maximum_identity() +{ + if constexpr (std::is_same_v) { return string_view::min(); } + return cuda::std::numeric_limits::lowest(); +} + +/** + * @brief Utility class to identify whether a type T is aggregated or ignored + * for ORC or PARQUET + * + * @tparam T Leaf column type + * @tparam IO File format for which statistics calculation is being done + */ +template +class statistics_type_category { + public: + // Types that calculate the sum of elements encountered + static constexpr bool is_aggregated = + (IO == io_file_format::PARQUET) ? false : aggregation_type::is_supported; + + // Types for which sum does not make sense + static constexpr bool is_not_aggregated = + (IO == io_file_format::PARQUET) ? aggregation_type::is_supported or cudf::is_timestamp() + : cudf::is_timestamp(); + + // Do not calculate statistics for any other type + static constexpr bool is_ignored = not(is_aggregated or is_not_aggregated); +}; + +} // namespace detail +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/statistics/temp_storage_wrapper.cuh b/cpp/src/io/statistics/temp_storage_wrapper.cuh new file mode 100644 index 00000000000..7a36c873ba6 --- /dev/null +++ b/cpp/src/io/statistics/temp_storage_wrapper.cuh @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file temp_storage_wrapper.cuh + * @brief Temporary storage for cub calls and helper wrapper class + */ + +#pragma once + +#include +#include +#include +#include + +#include "statistics.cuh" + +#include + +namespace cudf { +namespace io { +namespace detail { + +template +using cub_temp_storage = typename cub::BlockReduce::TempStorage; + +#define MEMBER_NAME(TYPE) TYPE##_stats + +#define DECLARE_MEMBER(TYPE) cub_temp_storage MEMBER_NAME(TYPE); + +/** + * @brief Templated union to hold temporary storage to be used by cub reduce + * calls + * + * @tparam block_size Dimension of the block + */ +template +union block_reduce_storage { + DECLARE_MEMBER(bool) + DECLARE_MEMBER(int8_t) + DECLARE_MEMBER(int16_t) + DECLARE_MEMBER(int32_t) + DECLARE_MEMBER(int64_t) + DECLARE_MEMBER(uint8_t) + DECLARE_MEMBER(uint16_t) + DECLARE_MEMBER(uint32_t) + DECLARE_MEMBER(uint64_t) + DECLARE_MEMBER(float) + DECLARE_MEMBER(double) + DECLARE_MEMBER(string_view) +}; + +#define STORAGE_WRAPPER_GET(TYPE) \ + template \ + __device__ std::enable_if_t, cub_temp_storage&> get() \ + { \ + return storage.MEMBER_NAME(TYPE); \ + } + +/** + * @brief Templated wrapper for block_reduce_storage to return member reference based on requested + * type + * + * @tparam block_size Dimension of the block + */ +template +struct storage_wrapper { + block_reduce_storage& storage; + __device__ storage_wrapper(block_reduce_storage& _temp_storage) + : storage(_temp_storage) + { + } + + STORAGE_WRAPPER_GET(bool); + STORAGE_WRAPPER_GET(int8_t); + STORAGE_WRAPPER_GET(int16_t); + STORAGE_WRAPPER_GET(int32_t); + STORAGE_WRAPPER_GET(int64_t); + STORAGE_WRAPPER_GET(uint8_t); + STORAGE_WRAPPER_GET(uint16_t); + STORAGE_WRAPPER_GET(uint32_t); + STORAGE_WRAPPER_GET(uint64_t); + STORAGE_WRAPPER_GET(float); + STORAGE_WRAPPER_GET(double); + STORAGE_WRAPPER_GET(string_view); +}; + +#undef DECLARE_MEMBER +#undef MEMBER_NAME + +} // namespace detail +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/statistics/typed_statistics_chunk.cuh b/cpp/src/io/statistics/typed_statistics_chunk.cuh new file mode 100644 index 00000000000..20b7fdc927b --- /dev/null +++ b/cpp/src/io/statistics/typed_statistics_chunk.cuh @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file typed_statistics_chunk + * @brief Templated wrapper to generalize statistics chunk reduction and aggregation + * across different leaf column types + */ + +#pragma once + +#include "statistics.cuh" +#include "statistics_type_identification.cuh" +#include "temp_storage_wrapper.cuh" + +#include +#include + +#include + +namespace cudf { +namespace io { + +/** + * @brief Class used to get reference to members of unions related to statistics calculations + */ +class union_member { + template + using reference_type = std::conditional_t, const V&, V&>; + + public: + template + using type = std::conditional_t, string_view>, + reference_type, + reference_type>; + + template + __device__ static std::enable_if_t and std::is_unsigned_v, type> + get(U& val) + { + return val.u_val; + } + + template + __device__ static std::enable_if_t and std::is_signed_v, type> get( + U& val) + { + return val.i_val; + } + + template + __device__ static std::enable_if_t, type> get(U& val) + { + return val.fp_val; + } + + template + __device__ static std::enable_if_t, type> get(U& val) + { + return val.str_val; + } +}; + +/** + * @brief Templated structure used for merging and gathering of statistics chunks + * + * This uses the reduce function to compute the minimum, maximum and aggregate + * values simultaneously. + * + * @tparam T The input type associated with the chunk + * @tparam is_aggregation_supported Set to true if input type is meant to be aggregated + */ +template +struct typed_statistics_chunk { +}; + +template +struct typed_statistics_chunk { + using E = typename detail::extrema_type::type; + using A = typename detail::aggregation_type::type; + + uint32_t num_rows; //!< number of non-null values in chunk + uint32_t non_nulls; //!< number of non-null values in chunk + uint32_t null_count; //!< number of null values in chunk + + E minimum_value; + E maximum_value; + A aggregate; + + uint8_t has_minmax; //!< Nonzero if min_value and max_values are valid + uint8_t has_sum; //!< Nonzero if sum is valid + + __device__ typed_statistics_chunk(const uint32_t _num_rows = 0) + : num_rows(_num_rows), + non_nulls(0), + null_count(0), + minimum_value(detail::minimum_identity()), + maximum_value(detail::maximum_identity()), + aggregate(0), + has_minmax(false), + has_sum(false) // Set to true when storing + { + } + + __device__ void reduce(const T& elem) + { + non_nulls++; + minimum_value = thrust::min(minimum_value, detail::extrema_type::convert(elem)); + maximum_value = thrust::max(maximum_value, detail::extrema_type::convert(elem)); + aggregate += detail::aggregation_type::convert(elem); + has_minmax = true; + } + + __device__ void reduce(const statistics_chunk& chunk) + { + if (chunk.has_minmax) { + minimum_value = thrust::min(minimum_value, union_member::get(chunk.min_value)); + maximum_value = thrust::max(maximum_value, union_member::get(chunk.max_value)); + } + if (chunk.has_sum) { + aggregate += detail::aggregation_type::convert(union_member::get(chunk.sum)); + } + non_nulls += chunk.non_nulls; + null_count += chunk.null_count; + num_rows += (chunk.non_nulls + chunk.null_count); + } +}; + +template +struct typed_statistics_chunk { + using E = typename detail::extrema_type::type; + + uint32_t num_rows; //!< number of non-null values in chunk + uint32_t non_nulls; //!< number of non-null values in chunk + uint32_t null_count; //!< number of null values in chunk + + E minimum_value; + E maximum_value; + + uint8_t has_minmax; //!< Nonzero if min_value and max_values are valid + uint8_t has_sum; //!< Nonzero if sum is valid + + __device__ typed_statistics_chunk(const uint32_t _num_rows = 0) + : num_rows(_num_rows), + non_nulls(0), + null_count(0), + minimum_value(detail::minimum_identity()), + maximum_value(detail::maximum_identity()), + has_minmax(false), + has_sum(false) // Set to true when storing + { + } + + __device__ void reduce(const T& elem) + { + non_nulls++; + minimum_value = thrust::min(minimum_value, detail::extrema_type::convert(elem)); + maximum_value = thrust::max(maximum_value, detail::extrema_type::convert(elem)); + has_minmax = true; + } + + __device__ void reduce(const statistics_chunk& chunk) + { + if (chunk.has_minmax) { + minimum_value = thrust::min(minimum_value, union_member::get(chunk.min_value)); + maximum_value = thrust::max(maximum_value, union_member::get(chunk.max_value)); + } + non_nulls += chunk.non_nulls; + null_count += chunk.null_count; + num_rows += (chunk.non_nulls + chunk.null_count); + } +}; + +/** + * @brief Function to reduce members of a typed_statistics_chunk across a thread block + * + * @tparam T Type associated with typed_statistics_chunk + * @tparam block_size Dimension of the thread block + * @param chunk The input typed_statistics_chunk + * @param storage Temporary storage to be used by cub calls + */ +template +__inline__ __device__ typed_statistics_chunk block_reduce( + typed_statistics_chunk& chunk, detail::storage_wrapper& storage) +{ + typed_statistics_chunk output_chunk = chunk; + + using E = typename detail::extrema_type::type; + using extrema_reduce = cub::BlockReduce; + using count_reduce = cub::BlockReduce; + output_chunk.minimum_value = + extrema_reduce(storage.template get()).Reduce(output_chunk.minimum_value, cub::Min()); + __syncthreads(); + output_chunk.maximum_value = + extrema_reduce(storage.template get()).Reduce(output_chunk.maximum_value, cub::Max()); + __syncthreads(); + output_chunk.non_nulls = + count_reduce(storage.template get()).Sum(output_chunk.non_nulls); + __syncthreads(); + output_chunk.null_count = + count_reduce(storage.template get()).Sum(output_chunk.null_count); + __syncthreads(); + output_chunk.has_minmax = __syncthreads_or(output_chunk.has_minmax); + + // FIXME : Is another syncthreads needed here? + if constexpr (is_aggregated) { + if (output_chunk.has_minmax) { + using A = typename detail::aggregation_type::type; + using aggregate_reduce = cub::BlockReduce; + output_chunk.aggregate = + aggregate_reduce(storage.template get()).Sum(output_chunk.aggregate); + } + } + return output_chunk; +} + +/** + * @brief Function to convert typed_statistics_chunk into statistics_chunk + * + * @tparam T Type associated with typed_statistics_chunk + * @param chunk The input typed_statistics_chunk + */ +template +__inline__ __device__ statistics_chunk +get_untyped_chunk(const typed_statistics_chunk& chunk) +{ + statistics_chunk stat; + stat.non_nulls = chunk.non_nulls; + stat.null_count = chunk.num_rows - chunk.non_nulls; + stat.has_minmax = chunk.has_minmax; + stat.has_sum = + chunk.has_minmax; // If a valid input was encountered we assume that the sum is valid + if (chunk.has_minmax) { + using E = typename detail::extrema_type::type; + if constexpr (std::is_floating_point_v) { + union_member::get(stat.min_value) = + (chunk.minimum_value != 0.0) ? chunk.minimum_value : CUDART_NEG_ZERO; + union_member::get(stat.max_value) = + (chunk.maximum_value != 0.0) ? chunk.maximum_value : CUDART_ZERO; + } else { + union_member::get(stat.min_value) = chunk.minimum_value; + union_member::get(stat.max_value) = chunk.maximum_value; + } + if constexpr (is_aggregated) { + using A = typename detail::aggregation_type::type; + union_member::get(stat.sum) = chunk.aggregate; + } + } + return stat; +} + +} // namespace io +} // namespace cudf diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 4781ff995b0..54bf17e4c2b 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -19,7 +19,11 @@ import cudf from cudf.io.parquet import ParquetWriter, merge_parquet_filemetadata from cudf.tests import dataset_generator as dg -from cudf.tests.utils import assert_eq, assert_exceptions_equal +from cudf.tests.utils import ( + TIMEDELTA_TYPES, + assert_eq, + assert_exceptions_equal, +) @pytest.fixture(scope="module") @@ -1782,6 +1786,9 @@ def test_parquet_writer_statistics(tmpdir, pdf): if "col_category" in pdf.columns: pdf = pdf.drop(columns=["col_category", "col_bool"]) + for t in TIMEDELTA_TYPES: + pdf["col_" + t] = pd.Series(np.arange(len(pdf.index))).astype(t) + gdf = cudf.from_pandas(pdf) gdf.to_parquet(file_path, index=False)