From b4fd77b30311f3b1de39cac22423f2c3a32ec72d Mon Sep 17 00:00:00 2001 From: nvdbaranec <56695930+nvdbaranec@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:20:42 -0500 Subject: [PATCH] Centralize chunked reading code in the parquet reader to reader_impl_chunking.cu (#14262) As a precursor to further chunked reader work, this PR centralizes chunk-related code (mostly from the `reader::impl` class) into `reader_impl_chunking.cu` and `reader_impl_chunking.hpp`. Also cleans up some variable naming and locations in `reader::impl` and the `file_intermediate_data` and `pass_intermediate_data classes`. Authors: - https://github.com/nvdbaranec Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Robert Maynard (https://github.com/robertmaynard) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/14262 --- cpp/CMakeLists.txt | 1 + cpp/src/io/parquet/parquet_gpu.hpp | 73 --- cpp/src/io/parquet/reader_impl.cpp | 12 +- cpp/src/io/parquet/reader_impl.hpp | 49 +- cpp/src/io/parquet/reader_impl_chunking.cu | 598 +++++++++++++++++++ cpp/src/io/parquet/reader_impl_chunking.hpp | 87 +++ cpp/src/io/parquet/reader_impl_helpers.hpp | 17 + cpp/src/io/parquet/reader_impl_preprocess.cu | 558 +---------------- cpp/src/io/utilities/column_buffer.cpp | 10 +- 9 files changed, 751 insertions(+), 654 deletions(-) create mode 100644 cpp/src/io/parquet/reader_impl_chunking.cu create mode 100644 cpp/src/io/parquet/reader_impl_chunking.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 000f80065ab..f8b9762f1d4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -401,6 +401,7 @@ add_library( src/io/parquet/predicate_pushdown.cpp src/io/parquet/reader.cpp src/io/parquet/reader_impl.cpp + src/io/parquet/reader_impl_chunking.cu src/io/parquet/reader_impl_helpers.cpp src/io/parquet/reader_impl_preprocess.cu src/io/parquet/writer_impl.cu diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 767668cc65e..6a93fec0c46 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -318,79 +318,6 @@ struct ColumnChunkDesc { int32_t src_col_schema{}; // my schema index in the file }; -/** - * @brief The row_group_info class - */ -struct row_group_info { - size_type index; // row group index within a file. aggregate_reader_metadata::get_row_group() is - // called with index and source_index - size_t start_row; - size_type source_index; // file index. - - row_group_info() = default; - - row_group_info(size_type index, size_t start_row, size_type source_index) - : index{index}, start_row{start_row}, source_index{source_index} - { - } -}; - -/** - * @brief Struct to store file-level data that remains constant for - * all passes/chunks for the file. - */ -struct file_intermediate_data { - // all row groups to read - std::vector row_groups{}; - - // all chunks from the selected row groups. We may end up reading these chunks progressively - // instead of all at once - std::vector chunks{}; - - // skip_rows/num_rows values for the entire file. these need to be adjusted per-pass because we - // may not be visiting every row group that contains these bounds - size_t global_skip_rows; - size_t global_num_rows; -}; - -/** - * @brief Structs to identify the reading row range for each chunk of rows in chunked reading. - */ -struct chunk_read_info { - size_t skip_rows; - size_t num_rows; -}; - -/** - * @brief Struct to store pass-level data that remains constant for a single pass. - */ -struct pass_intermediate_data { - std::vector> raw_page_data; - rmm::device_buffer decomp_page_data; - - // rowgroup, chunk and page information for the current pass. - std::vector row_groups{}; - cudf::detail::hostdevice_vector chunks{}; - cudf::detail::hostdevice_vector pages_info{}; - cudf::detail::hostdevice_vector page_nesting_info{}; - cudf::detail::hostdevice_vector page_nesting_decode_info{}; - - rmm::device_uvector page_keys{0, rmm::cuda_stream_default}; - rmm::device_uvector page_index{0, rmm::cuda_stream_default}; - rmm::device_uvector str_dict_index{0, rmm::cuda_stream_default}; - - std::vector output_chunk_read_info; - std::size_t current_output_chunk{0}; - - rmm::device_buffer level_decode_data{}; - int level_type_size{0}; - - // skip_rows and num_rows values for this particular pass. these may be adjusted values from the - // global values stored in file_intermediate_data. - size_t skip_rows; - size_t num_rows; -}; - /** * @brief Struct describing an encoder column */ diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 26ec83d5946..db81222157a 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -349,14 +349,14 @@ void reader::impl::prepare_data(int64_t skip_rows, not _input_columns.empty()) { // fills in chunk information without physically loading or decompressing // the associated data - load_global_chunk_info(); + create_global_chunk_info(); // compute schedule of input reads. Each rowgroup contains 1 chunk per column. For now // we will read an entire row group at a time. However, it is possible to do // sub-rowgroup reads if we made some estimates on individual chunk sizes (tricky) and // changed the high level structure such that we weren't always reading an entire table's // worth of columns at once. - compute_input_pass_row_group_info(); + compute_input_passes(); } _file_preprocessed = true; @@ -364,7 +364,7 @@ void reader::impl::prepare_data(int64_t skip_rows, // if we have to start a new pass, do that now if (!_pass_preprocessed) { - auto const num_passes = _input_pass_row_group_offsets.size() - 1; + auto const num_passes = _file_itm_data.input_pass_row_group_offsets.size() - 1; // always create the pass struct, even if we end up with no passes. // this will also cause the previous pass information to be deleted @@ -373,7 +373,7 @@ void reader::impl::prepare_data(int64_t skip_rows, if (_file_itm_data.global_num_rows > 0 && not _file_itm_data.row_groups.empty() && not _input_columns.empty() && _current_input_pass < num_passes) { // setup the pass_intermediate_info for this pass. - setup_pass(); + setup_next_pass(); load_and_decompress_data(); preprocess_pages(uses_custom_row_bounds, _output_chunk_read_limit); @@ -541,8 +541,8 @@ bool reader::impl::has_next() {} /*row_group_indices, empty means read all row groups*/, std::nullopt /*filter*/); - auto const num_input_passes = - _input_pass_row_group_offsets.size() == 0 ? 0 : _input_pass_row_group_offsets.size() - 1; + size_t const num_input_passes = std::max( + int64_t{0}, static_cast(_file_itm_data.input_pass_row_group_offsets.size()) - 1); return (_pass_itm_data->current_output_chunk < _pass_itm_data->output_chunk_read_info.size()) || (_current_input_pass < num_input_passes); } diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index 6003b931b04..cea4ba35606 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -22,6 +22,7 @@ #pragma once #include "parquet_gpu.hpp" +#include "reader_impl_chunking.hpp" #include "reader_impl_helpers.hpp" #include @@ -136,10 +137,6 @@ class reader::impl { host_span const> row_group_indices, std::optional> filter); - void load_global_chunk_info(); - void compute_input_pass_row_group_info(); - void setup_pass(); - /** * @brief Create chunk information and start file reads * @@ -250,6 +247,31 @@ class reader::impl { */ void decode_page_data(size_t skip_rows, size_t num_rows); + /** + * @brief Creates file-wide parquet chunk information. + * + * Creates information about all chunks in the file, storing it in + * the file-wide _file_itm_data structure. + */ + void create_global_chunk_info(); + + /** + * @brief Computes all of the passes we will perform over the file. + */ + void compute_input_passes(); + + /** + * @brief Close out the existing pass (if any) and prepare for the next pass. + */ + void setup_next_pass(); + + /** + * @brief Given a set of pages that have had their sizes computed by nesting level and + * a limit on total read size, generate a set of {skip_rows, num_rows} pairs representing + * a set of reads that will generate output columns of total size <= `chunk_read_limit` bytes. + */ + void compute_splits_for_pass(); + private: rmm::cuda_stream_view _stream; rmm::mr::device_memory_resource* _mr = nullptr; @@ -278,7 +300,7 @@ class reader::impl { // chunked reading happens in 2 parts: // - // At the top level there is the "pass" in which we try and limit the + // At the top level, the entire file is divided up into "passes" omn which we try and limit the // total amount of temporary memory (compressed data, decompressed data) in use // via _input_pass_read_limit. // @@ -286,19 +308,16 @@ class reader::impl { // byte size is controlled by _output_chunk_read_limit. file_intermediate_data _file_itm_data; + bool _file_preprocessed{false}; + std::unique_ptr _pass_itm_data; + bool _pass_preprocessed{false}; - // an array of offsets into _file_itm_data::global_chunks. Each pair of offsets represents - // the start/end of the chunks to be loaded for a given pass. - std::vector _input_pass_row_group_offsets{}; - std::vector _input_pass_row_count{}; - std::size_t _current_input_pass{0}; - std::size_t _chunk_count{0}; + std::size_t _output_chunk_read_limit{0}; // output chunk size limit in bytes + std::size_t _input_pass_read_limit{0}; // input pass memory usage limit in bytes - std::size_t _output_chunk_read_limit{0}; - std::size_t _input_pass_read_limit{0}; - bool _pass_preprocessed{false}; - bool _file_preprocessed{false}; + std::size_t _current_input_pass{0}; // current input pass index + std::size_t _chunk_count{0}; // how many output chunks we have produced }; } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_chunking.cu b/cpp/src/io/parquet/reader_impl_chunking.cu new file mode 100644 index 00000000000..ad52a7dfcc1 --- /dev/null +++ b/cpp/src/io/parquet/reader_impl_chunking.cu @@ -0,0 +1,598 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "reader_impl.hpp" +#include "reader_impl_chunking.hpp" + +#include +#include + +#include + +#include + +#include +#include +#include +#include + +namespace cudf::io::parquet::detail { + +namespace { + +struct cumulative_row_info { + size_t row_count; // cumulative row count + size_t size_bytes; // cumulative size in bytes + int key; // schema index +}; + +#if defined(CHUNKING_DEBUG) +void print_cumulative_page_info(cudf::detail::hostdevice_vector& pages, + rmm::device_uvector const& page_index, + rmm::device_uvector const& c_info, + rmm::cuda_stream_view stream) +{ + pages.device_to_host_sync(stream); + + printf("------------\nCumulative sizes by page\n"); + + std::vector schemas(pages.size()); + std::vector h_page_index(pages.size()); + CUDF_CUDA_TRY(cudaMemcpy( + h_page_index.data(), page_index.data(), sizeof(int) * pages.size(), cudaMemcpyDefault)); + std::vector h_cinfo(pages.size()); + CUDF_CUDA_TRY(cudaMemcpy( + h_cinfo.data(), c_info.data(), sizeof(cumulative_row_info) * pages.size(), cudaMemcpyDefault)); + auto schema_iter = cudf::detail::make_counting_transform_iterator( + 0, [&](size_type i) { return pages[h_page_index[i]].src_col_schema; }); + thrust::copy(thrust::seq, schema_iter, schema_iter + pages.size(), schemas.begin()); + auto last = thrust::unique(thrust::seq, schemas.begin(), schemas.end()); + schemas.resize(last - schemas.begin()); + printf("Num schemas: %lu\n", schemas.size()); + + for (size_t idx = 0; idx < schemas.size(); idx++) { + printf("Schema %d\n", schemas[idx]); + for (size_t pidx = 0; pidx < pages.size(); pidx++) { + auto const& page = pages[h_page_index[pidx]]; + if (page.flags & PAGEINFO_FLAGS_DICTIONARY || page.src_col_schema != schemas[idx]) { + continue; + } + printf("\tP: {%lu, %lu}\n", h_cinfo[pidx].row_count, h_cinfo[pidx].size_bytes); + } + } +} + +void print_cumulative_row_info(host_span sizes, + std::string const& label, + std::optional> splits = std::nullopt) +{ + if (splits.has_value()) { + printf("------------\nSplits\n"); + for (size_t idx = 0; idx < splits->size(); idx++) { + printf("{%lu, %lu}\n", splits.value()[idx].skip_rows, splits.value()[idx].num_rows); + } + } + + printf("------------\nCumulative sizes %s\n", label.c_str()); + for (size_t idx = 0; idx < sizes.size(); idx++) { + printf("{%lu, %lu, %d}", sizes[idx].row_count, sizes[idx].size_bytes, sizes[idx].key); + if (splits.has_value()) { + // if we have a split at this row count and this is the last instance of this row count + auto start = thrust::make_transform_iterator( + splits->begin(), [](chunk_read_info const& i) { return i.skip_rows; }); + auto end = start + splits->size(); + auto split = std::find(start, end, sizes[idx].row_count); + auto const split_index = [&]() -> int { + if (split != end && + ((idx == sizes.size() - 1) || (sizes[idx + 1].row_count > sizes[idx].row_count))) { + return static_cast(std::distance(start, split)); + } + return idx == 0 ? 0 : -1; + }(); + if (split_index >= 0) { + printf(" <-- split {%lu, %lu}", + splits.value()[split_index].skip_rows, + splits.value()[split_index].num_rows); + } + } + printf("\n"); + } +} +#endif // CHUNKING_DEBUG + +/** + * @brief Functor which reduces two cumulative_row_info structs of the same key. + */ +struct cumulative_row_sum { + cumulative_row_info operator() + __device__(cumulative_row_info const& a, cumulative_row_info const& b) const + { + return cumulative_row_info{a.row_count + b.row_count, a.size_bytes + b.size_bytes, a.key}; + } +}; + +/** + * @brief Functor which computes the total data size for a given type of cudf column. + * + * In the case of strings, the return size does not include the chars themselves. That + * information is tracked separately (see PageInfo::str_bytes). + */ +struct row_size_functor { + __device__ size_t validity_size(size_t num_rows, bool nullable) + { + return nullable ? (cudf::util::div_rounding_up_safe(num_rows, size_t{32}) * 4) : 0; + } + + template + __device__ size_t operator()(size_t num_rows, bool nullable) + { + auto const element_size = sizeof(device_storage_type_t); + return (element_size * num_rows) + validity_size(num_rows, nullable); + } +}; + +template <> +__device__ size_t row_size_functor::operator()(size_t num_rows, bool nullable) +{ + auto const offset_size = sizeof(size_type); + // NOTE: Adding the + 1 offset here isn't strictly correct. There will only be 1 extra offset + // for the entire column, whereas this is adding an extra offset per page. So we will get a + // small over-estimate of the real size of the order : # of pages * 4 bytes. It seems better + // to overestimate size somewhat than to underestimate it and potentially generate chunks + // that are too large. + return (offset_size * (num_rows + 1)) + validity_size(num_rows, nullable); +} + +template <> +__device__ size_t row_size_functor::operator()(size_t num_rows, bool nullable) +{ + return validity_size(num_rows, nullable); +} + +template <> +__device__ size_t row_size_functor::operator()(size_t num_rows, bool nullable) +{ + // only returns the size of offsets and validity. the size of the actual string chars + // is tracked separately. + auto const offset_size = sizeof(size_type); + // see note about offsets in the list_view template. + return (offset_size * (num_rows + 1)) + validity_size(num_rows, nullable); +} + +/** + * @brief Functor which computes the total output cudf data size for all of + * the data in this page. + * + * Sums across all nesting levels. + */ +struct get_cumulative_row_info { + PageInfo const* const pages; + + __device__ cumulative_row_info operator()(size_type index) + { + auto const& page = pages[index]; + if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { + return cumulative_row_info{0, 0, page.src_col_schema}; + } + + // total nested size, not counting string data + auto iter = + cudf::detail::make_counting_transform_iterator(0, [page, index] __device__(size_type i) { + auto const& pni = page.nesting[i]; + return cudf::type_dispatcher( + data_type{pni.type}, row_size_functor{}, pni.size, pni.nullable); + }); + + size_t const row_count = static_cast(page.nesting[0].size); + return { + row_count, + thrust::reduce(thrust::seq, iter, iter + page.num_output_nesting_levels) + page.str_bytes, + page.src_col_schema}; + } +}; + +/** + * @brief Functor which computes the effective size of all input columns by page. + * + * For a given row, we want to find the cost of all pages for all columns involved + * in loading up to that row. The complication here is that not all pages are the + * same size between columns. Example: + * + * page row counts + * Column A: 0 <----> 100 <----> 200 + * Column B: 0 <---------------> 200 <--------> 400 + | + * if we decide to split at row 100, we don't really know the actual amount of bytes in column B + * at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that + * page. Essentially, a conservative over-estimate of the real size. + */ +struct row_total_size { + cumulative_row_info const* c_info; + size_type const* key_offsets; + size_t num_keys; + + __device__ cumulative_row_info operator()(cumulative_row_info const& i) + { + // sum sizes for each input column at this row + size_t sum = 0; + for (int idx = 0; idx < num_keys; idx++) { + auto const start = key_offsets[idx]; + auto const end = key_offsets[idx + 1]; + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [&] __device__(size_type i) { return c_info[i].row_count; }); + auto const page_index = + thrust::lower_bound(thrust::seq, iter + start, iter + end, i.row_count) - iter; + sum += c_info[page_index].size_bytes; + } + return {i.row_count, sum, i.key}; + } +}; + +/** + * @brief Given a vector of cumulative {row_count, byte_size} pairs and a chunk read + * limit, determine the set of splits. + * + * @param sizes Vector of cumulative {row_count, byte_size} pairs + * @param num_rows Total number of rows to read + * @param chunk_read_limit Limit on total number of bytes to be returned per read, for all columns + */ +std::vector find_splits(std::vector const& sizes, + size_t num_rows, + size_t chunk_read_limit) +{ + // now we have an array of {row_count, real output bytes}. just walk through it and generate + // splits. + // TODO: come up with a clever way to do this entirely in parallel. For now, as long as batch + // sizes are reasonably large, this shouldn't iterate too many times + std::vector splits; + { + size_t cur_pos = 0; + size_t cur_cumulative_size = 0; + size_t cur_row_count = 0; + auto start = thrust::make_transform_iterator(sizes.begin(), [&](cumulative_row_info const& i) { + return i.size_bytes - cur_cumulative_size; + }); + auto end = start + sizes.size(); + while (cur_row_count < num_rows) { + int64_t split_pos = + thrust::lower_bound(thrust::seq, start + cur_pos, end, chunk_read_limit) - start; + + // if we're past the end, or if the returned bucket is > than the chunk_read_limit, move back + // one. + if (static_cast(split_pos) >= sizes.size() || + (sizes[split_pos].size_bytes - cur_cumulative_size > chunk_read_limit)) { + split_pos--; + } + + // best-try. if we can't find something that'll fit, we have to go bigger. we're doing this in + // a loop because all of the cumulative sizes for all the pages are sorted into one big list. + // so if we had two columns, both of which had an entry {1000, 10000}, that entry would be in + // the list twice. so we have to iterate until we skip past all of them. The idea is that we + // either do this, or we have to call unique() on the input first. + while (split_pos < (static_cast(sizes.size()) - 1) && + (split_pos < 0 || sizes[split_pos].row_count == cur_row_count)) { + split_pos++; + } + + auto const start_row = cur_row_count; + cur_row_count = sizes[split_pos].row_count; + splits.push_back(chunk_read_info{start_row, cur_row_count - start_row}); + cur_pos = split_pos; + cur_cumulative_size = sizes[split_pos].size_bytes; + } + } + // print_cumulative_row_info(sizes, "adjusted", splits); + + return splits; +} + +/** + * @brief Converts cuDF units to Parquet units. + * + * @return A tuple of Parquet type width, Parquet clock rate and Parquet decimal type. + */ +[[nodiscard]] std::tuple conversion_info(type_id column_type_id, + type_id timestamp_type_id, + Type physical, + int8_t converted, + int32_t length) +{ + int32_t type_width = (physical == FIXED_LEN_BYTE_ARRAY) ? length : 0; + int32_t clock_rate = 0; + if (column_type_id == type_id::INT8 or column_type_id == type_id::UINT8) { + type_width = 1; // I32 -> I8 + } else if (column_type_id == type_id::INT16 or column_type_id == type_id::UINT16) { + type_width = 2; // I32 -> I16 + } else if (column_type_id == type_id::INT32) { + type_width = 4; // str -> hash32 + } else if (is_chrono(data_type{column_type_id})) { + clock_rate = to_clockrate(timestamp_type_id); + } + + int8_t converted_type = converted; + if (converted_type == DECIMAL && column_type_id != type_id::FLOAT64 && + not cudf::is_fixed_point(data_type{column_type_id})) { + converted_type = UNKNOWN; // Not converting to float64 or decimal + } + return std::make_tuple(type_width, clock_rate, converted_type); +} + +/** + * @brief Return the required number of bits to store a value. + */ +template +[[nodiscard]] T required_bits(uint32_t max_level) +{ + return static_cast(CompactProtocolReader::NumRequiredBits(max_level)); +} + +struct row_count_compare { + __device__ bool operator()(cumulative_row_info const& a, cumulative_row_info const& b) + { + return a.row_count < b.row_count; + } +}; + +} // anonymous namespace + +void reader::impl::create_global_chunk_info() +{ + auto const num_rows = _file_itm_data.global_num_rows; + auto const& row_groups_info = _file_itm_data.row_groups; + auto& chunks = _file_itm_data.chunks; + + // Descriptors for all the chunks that make up the selected columns + auto const num_input_columns = _input_columns.size(); + auto const num_chunks = row_groups_info.size() * num_input_columns; + + // Initialize column chunk information + auto remaining_rows = num_rows; + for (auto const& rg : row_groups_info) { + auto const& row_group = _metadata->get_row_group(rg.index, rg.source_index); + auto const row_group_start = rg.start_row; + auto const row_group_rows = std::min(remaining_rows, row_group.num_rows); + + // generate ColumnChunkDesc objects for everything to be decoded (all input columns) + for (size_t i = 0; i < num_input_columns; ++i) { + auto col = _input_columns[i]; + // look up metadata + auto& col_meta = _metadata->get_column_metadata(rg.index, rg.source_index, col.schema_idx); + auto& schema = _metadata->get_schema(col.schema_idx); + + auto [type_width, clock_rate, converted_type] = + conversion_info(to_type_id(schema, _strings_to_categorical, _timestamp_type.id()), + _timestamp_type.id(), + schema.type, + schema.converted_type, + schema.type_length); + + chunks.push_back(ColumnChunkDesc(col_meta.total_compressed_size, + nullptr, + col_meta.num_values, + schema.type, + type_width, + row_group_start, + row_group_rows, + schema.max_definition_level, + schema.max_repetition_level, + _metadata->get_output_nesting_depth(col.schema_idx), + required_bits(schema.max_definition_level), + required_bits(schema.max_repetition_level), + col_meta.codec, + converted_type, + schema.logical_type, + schema.decimal_precision, + clock_rate, + i, + col.schema_idx)); + } + + remaining_rows -= row_group_rows; + } +} + +void reader::impl::compute_input_passes() +{ + // at this point, row_groups has already been filtered down to just the row groups we need to + // handle optional skip_rows/num_rows parameters. + auto const& row_groups_info = _file_itm_data.row_groups; + + // if the user hasn't specified an input size limit, read everything in a single pass. + if (_input_pass_read_limit == 0) { + _file_itm_data.input_pass_row_group_offsets.push_back(0); + _file_itm_data.input_pass_row_group_offsets.push_back(row_groups_info.size()); + return; + } + + // generate passes. make sure to account for the case where a single row group doesn't fit within + // + std::size_t const read_limit = + _input_pass_read_limit > 0 ? _input_pass_read_limit : std::numeric_limits::max(); + std::size_t cur_pass_byte_size = 0; + std::size_t cur_rg_start = 0; + std::size_t cur_row_count = 0; + _file_itm_data.input_pass_row_group_offsets.push_back(0); + _file_itm_data.input_pass_row_count.push_back(0); + + for (size_t cur_rg_index = 0; cur_rg_index < row_groups_info.size(); cur_rg_index++) { + auto const& rgi = row_groups_info[cur_rg_index]; + auto const& row_group = _metadata->get_row_group(rgi.index, rgi.source_index); + + // can we add this row group + if (cur_pass_byte_size + row_group.total_byte_size >= read_limit) { + // A single row group (the current one) is larger than the read limit: + // We always need to include at least one row group, so end the pass at the end of the current + // row group + if (cur_rg_start == cur_rg_index) { + _file_itm_data.input_pass_row_group_offsets.push_back(cur_rg_index + 1); + _file_itm_data.input_pass_row_count.push_back(cur_row_count + row_group.num_rows); + cur_rg_start = cur_rg_index + 1; + cur_pass_byte_size = 0; + } + // End the pass at the end of the previous row group + else { + _file_itm_data.input_pass_row_group_offsets.push_back(cur_rg_index); + _file_itm_data.input_pass_row_count.push_back(cur_row_count); + cur_rg_start = cur_rg_index; + cur_pass_byte_size = row_group.total_byte_size; + } + } else { + cur_pass_byte_size += row_group.total_byte_size; + } + cur_row_count += row_group.num_rows; + } + // add the last pass if necessary + if (_file_itm_data.input_pass_row_group_offsets.back() != row_groups_info.size()) { + _file_itm_data.input_pass_row_group_offsets.push_back(row_groups_info.size()); + _file_itm_data.input_pass_row_count.push_back(cur_row_count); + } +} + +void reader::impl::setup_next_pass() +{ + // this will also cause the previous pass information to be deleted + _pass_itm_data = std::make_unique(); + + // setup row groups to be loaded for this pass + auto const row_group_start = _file_itm_data.input_pass_row_group_offsets[_current_input_pass]; + auto const row_group_end = _file_itm_data.input_pass_row_group_offsets[_current_input_pass + 1]; + auto const num_row_groups = row_group_end - row_group_start; + _pass_itm_data->row_groups.resize(num_row_groups); + std::copy(_file_itm_data.row_groups.begin() + row_group_start, + _file_itm_data.row_groups.begin() + row_group_end, + _pass_itm_data->row_groups.begin()); + + auto const num_passes = _file_itm_data.input_pass_row_group_offsets.size() - 1; + CUDF_EXPECTS(_current_input_pass < num_passes, "Encountered an invalid read pass index"); + + auto const chunks_per_rowgroup = _input_columns.size(); + auto const num_chunks = chunks_per_rowgroup * num_row_groups; + + auto chunk_start = _file_itm_data.chunks.begin() + (row_group_start * chunks_per_rowgroup); + auto chunk_end = _file_itm_data.chunks.begin() + (row_group_end * chunks_per_rowgroup); + + _pass_itm_data->chunks = cudf::detail::hostdevice_vector(num_chunks, _stream); + std::copy(chunk_start, chunk_end, _pass_itm_data->chunks.begin()); + + // adjust skip_rows and num_rows by what's available in the row groups we are processing + if (num_passes == 1) { + _pass_itm_data->skip_rows = _file_itm_data.global_skip_rows; + _pass_itm_data->num_rows = _file_itm_data.global_num_rows; + } else { + auto const global_start_row = _file_itm_data.global_skip_rows; + auto const global_end_row = global_start_row + _file_itm_data.global_num_rows; + auto const start_row = + std::max(_file_itm_data.input_pass_row_count[_current_input_pass], global_start_row); + auto const end_row = + std::min(_file_itm_data.input_pass_row_count[_current_input_pass + 1], global_end_row); + + // skip_rows is always global in the sense that it is relative to the first row of + // everything we will be reading, regardless of what pass we are on. + // num_rows is how many rows we are reading this pass. + _pass_itm_data->skip_rows = + global_start_row + _file_itm_data.input_pass_row_count[_current_input_pass]; + _pass_itm_data->num_rows = end_row - start_row; + } +} + +void reader::impl::compute_splits_for_pass() +{ + auto const skip_rows = _pass_itm_data->skip_rows; + auto const num_rows = _pass_itm_data->num_rows; + + // simple case : no chunk size, no splits + if (_output_chunk_read_limit <= 0) { + _pass_itm_data->output_chunk_read_info = std::vector{{skip_rows, num_rows}}; + return; + } + + auto& pages = _pass_itm_data->pages_info; + + auto const& page_keys = _pass_itm_data->page_keys; + auto const& page_index = _pass_itm_data->page_index; + + // generate cumulative row counts and sizes + rmm::device_uvector c_info(page_keys.size(), _stream); + // convert PageInfo to cumulative_row_info + auto page_input = thrust::make_transform_iterator(page_index.begin(), + get_cumulative_row_info{pages.device_ptr()}); + thrust::inclusive_scan_by_key(rmm::exec_policy(_stream), + page_keys.begin(), + page_keys.end(), + page_input, + c_info.begin(), + thrust::equal_to{}, + cumulative_row_sum{}); + // print_cumulative_page_info(pages, page_index, c_info, stream); + + // sort by row count + rmm::device_uvector c_info_sorted{c_info, _stream}; + thrust::sort( + rmm::exec_policy(_stream), c_info_sorted.begin(), c_info_sorted.end(), row_count_compare{}); + + // std::vector h_c_info_sorted(c_info_sorted.size()); + // CUDF_CUDA_TRY(cudaMemcpy(h_c_info_sorted.data(), + // c_info_sorted.data(), + // sizeof(cumulative_row_info) * c_info_sorted.size(), + // cudaMemcpyDefault)); + // print_cumulative_row_info(h_c_info_sorted, "raw"); + + // generate key offsets (offsets to the start of each partition of keys). worst case is 1 page per + // key + rmm::device_uvector key_offsets(page_keys.size() + 1, _stream); + auto const key_offsets_end = thrust::reduce_by_key(rmm::exec_policy(_stream), + page_keys.begin(), + page_keys.end(), + thrust::make_constant_iterator(1), + thrust::make_discard_iterator(), + key_offsets.begin()) + .second; + size_t const num_unique_keys = key_offsets_end - key_offsets.begin(); + thrust::exclusive_scan( + rmm::exec_policy(_stream), key_offsets.begin(), key_offsets.end(), key_offsets.begin()); + + // adjust the cumulative info such that for each row count, the size includes any pages that span + // that row count. this is so that if we have this case: + // page row counts + // Column A: 0 <----> 100 <----> 200 + // Column B: 0 <---------------> 200 <--------> 400 + // | + // if we decide to split at row 100, we don't really know the actual amount of bytes in column B + // at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that + // page. + // + rmm::device_uvector aggregated_info(c_info.size(), _stream); + thrust::transform(rmm::exec_policy(_stream), + c_info_sorted.begin(), + c_info_sorted.end(), + aggregated_info.begin(), + row_total_size{c_info.data(), key_offsets.data(), num_unique_keys}); + + // bring back to the cpu + std::vector h_aggregated_info(aggregated_info.size()); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_aggregated_info.data(), + aggregated_info.data(), + sizeof(cumulative_row_info) * c_info.size(), + cudaMemcpyDefault, + _stream.value())); + _stream.synchronize(); + + // generate the actual splits + _pass_itm_data->output_chunk_read_info = + find_splits(h_aggregated_info, num_rows, _output_chunk_read_limit); +} + +} // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_chunking.hpp b/cpp/src/io/parquet/reader_impl_chunking.hpp new file mode 100644 index 00000000000..dfc239d8451 --- /dev/null +++ b/cpp/src/io/parquet/reader_impl_chunking.hpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "reader_impl_helpers.hpp" + +#include + +namespace cudf::io::parquet::detail { + +/** + * @brief Struct to store file-level data that remains constant for + * all passes/chunks in the file. + */ +struct file_intermediate_data { + // all row groups to read + std::vector row_groups{}; + + // all chunks from the selected row groups. We may end up reading these chunks progressively + // instead of all at once + std::vector chunks{}; + + // an array of offsets into _file_itm_data::global_chunks. Each pair of offsets represents + // the start/end of the chunks to be loaded for a given pass. + std::vector input_pass_row_group_offsets{}; + // row counts per input-pass + std::vector input_pass_row_count{}; + + // skip_rows/num_rows values for the entire file. these need to be adjusted per-pass because we + // may not be visiting every row group that contains these bounds + size_t global_skip_rows; + size_t global_num_rows; +}; + +/** + * @brief Struct to identify the range for each chunk of rows during a chunked reading pass. + */ +struct chunk_read_info { + size_t skip_rows; + size_t num_rows; +}; + +/** + * @brief Struct to store pass-level data that remains constant for a single pass. + */ +struct pass_intermediate_data { + std::vector> raw_page_data; + rmm::device_buffer decomp_page_data; + + // rowgroup, chunk and page information for the current pass. + std::vector row_groups{}; + cudf::detail::hostdevice_vector chunks{}; + cudf::detail::hostdevice_vector pages_info{}; + cudf::detail::hostdevice_vector page_nesting_info{}; + cudf::detail::hostdevice_vector page_nesting_decode_info{}; + + rmm::device_uvector page_keys{0, rmm::cuda_stream_default}; + rmm::device_uvector page_index{0, rmm::cuda_stream_default}; + rmm::device_uvector str_dict_index{0, rmm::cuda_stream_default}; + + std::vector output_chunk_read_info; + std::size_t current_output_chunk{0}; + + rmm::device_buffer level_decode_data{}; + int level_type_size{0}; + + // skip_rows and num_rows values for this particular pass. these may be adjusted values from the + // global values stored in file_intermediate_data. + size_t skip_rows; + size_t num_rows; +}; + +} // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index 1a73e2f55ac..8d8ab8707be 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -34,6 +34,23 @@ namespace cudf::io::parquet::detail { +/** + * @brief The row_group_info class + */ +struct row_group_info { + size_type index; // row group index within a file. aggregate_reader_metadata::get_row_group() is + // called with index and source_index + size_t start_row; + size_type source_index; // file index. + + row_group_info() = default; + + row_group_info(size_type index, size_t start_row, size_type source_index) + : index{index}, start_row{start_row}, source_index{source_index} + { + } +}; + /** * @brief Function that translates Parquet datatype to cuDF type enum */ diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 4bc6bb6f43b..ce45f709ee1 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -18,7 +18,6 @@ #include #include -#include #include #include @@ -44,7 +43,6 @@ #include namespace cudf::io::parquet::detail { - namespace { /** @@ -170,46 +168,6 @@ void generate_depth_remappings(std::map, std::ve } } -/** - * @brief Return the required number of bits to store a value. - */ -template -[[nodiscard]] T required_bits(uint32_t max_level) -{ - return static_cast(CompactProtocolReader::NumRequiredBits(max_level)); -} - -/** - * @brief Converts cuDF units to Parquet units. - * - * @return A tuple of Parquet type width, Parquet clock rate and Parquet decimal type. - */ -[[nodiscard]] std::tuple conversion_info(type_id column_type_id, - type_id timestamp_type_id, - Type physical, - int8_t converted, - int32_t length) -{ - int32_t type_width = (physical == FIXED_LEN_BYTE_ARRAY) ? length : 0; - int32_t clock_rate = 0; - if (column_type_id == type_id::INT8 or column_type_id == type_id::UINT8) { - type_width = 1; // I32 -> I8 - } else if (column_type_id == type_id::INT16 or column_type_id == type_id::UINT16) { - type_width = 2; // I32 -> I16 - } else if (column_type_id == type_id::INT32) { - type_width = 4; // str -> hash32 - } else if (is_chrono(data_type{column_type_id})) { - clock_rate = to_clockrate(timestamp_type_id); - } - - int8_t converted_type = converted; - if (converted_type == DECIMAL && column_type_id != type_id::FLOAT64 && - not cudf::is_fixed_point(data_type{column_type_id})) { - converted_type = UNKNOWN; // Not converting to float64 or decimal - } - return std::make_tuple(type_width, clock_rate, converted_type); -} - /** * @brief Reads compressed page data to device memory. * @@ -790,163 +748,6 @@ std::pair>> reader::impl::read_and_decompres return {total_decompressed_size > 0, std::move(read_chunk_tasks)}; } -void reader::impl::load_global_chunk_info() -{ - auto const num_rows = _file_itm_data.global_num_rows; - auto const& row_groups_info = _file_itm_data.row_groups; - auto& chunks = _file_itm_data.chunks; - - // Descriptors for all the chunks that make up the selected columns - auto const num_input_columns = _input_columns.size(); - auto const num_chunks = row_groups_info.size() * num_input_columns; - - // Initialize column chunk information - auto remaining_rows = num_rows; - for (auto const& rg : row_groups_info) { - auto const& row_group = _metadata->get_row_group(rg.index, rg.source_index); - auto const row_group_start = rg.start_row; - auto const row_group_rows = std::min(remaining_rows, row_group.num_rows); - - // generate ColumnChunkDesc objects for everything to be decoded (all input columns) - for (size_t i = 0; i < num_input_columns; ++i) { - auto col = _input_columns[i]; - // look up metadata - auto& col_meta = _metadata->get_column_metadata(rg.index, rg.source_index, col.schema_idx); - auto& schema = _metadata->get_schema(col.schema_idx); - - auto [type_width, clock_rate, converted_type] = - conversion_info(to_type_id(schema, _strings_to_categorical, _timestamp_type.id()), - _timestamp_type.id(), - schema.type, - schema.converted_type, - schema.type_length); - - chunks.push_back(ColumnChunkDesc(col_meta.total_compressed_size, - nullptr, - col_meta.num_values, - schema.type, - type_width, - row_group_start, - row_group_rows, - schema.max_definition_level, - schema.max_repetition_level, - _metadata->get_output_nesting_depth(col.schema_idx), - required_bits(schema.max_definition_level), - required_bits(schema.max_repetition_level), - col_meta.codec, - converted_type, - schema.logical_type, - schema.decimal_precision, - clock_rate, - i, - col.schema_idx)); - } - - remaining_rows -= row_group_rows; - } -} - -void reader::impl::compute_input_pass_row_group_info() -{ - // at this point, row_groups has already been filtered down to just the row groups we need to - // handle optional skip_rows/num_rows parameters. - auto const& row_groups_info = _file_itm_data.row_groups; - - // if the user hasn't specified an input size limit, read everything in a single pass. - if (_input_pass_read_limit == 0) { - _input_pass_row_group_offsets.push_back(0); - _input_pass_row_group_offsets.push_back(row_groups_info.size()); - return; - } - - // generate passes. make sure to account for the case where a single row group doesn't fit within - // - std::size_t const read_limit = - _input_pass_read_limit > 0 ? _input_pass_read_limit : std::numeric_limits::max(); - std::size_t cur_pass_byte_size = 0; - std::size_t cur_rg_start = 0; - std::size_t cur_row_count = 0; - _input_pass_row_group_offsets.push_back(0); - _input_pass_row_count.push_back(0); - - for (size_t cur_rg_index = 0; cur_rg_index < row_groups_info.size(); cur_rg_index++) { - auto const& rgi = row_groups_info[cur_rg_index]; - auto const& row_group = _metadata->get_row_group(rgi.index, rgi.source_index); - - // can we add this row group - if (cur_pass_byte_size + row_group.total_byte_size >= read_limit) { - // A single row group (the current one) is larger than the read limit: - // We always need to include at least one row group, so end the pass at the end of the current - // row group - if (cur_rg_start == cur_rg_index) { - _input_pass_row_group_offsets.push_back(cur_rg_index + 1); - _input_pass_row_count.push_back(cur_row_count + row_group.num_rows); - cur_rg_start = cur_rg_index + 1; - cur_pass_byte_size = 0; - } - // End the pass at the end of the previous row group - else { - _input_pass_row_group_offsets.push_back(cur_rg_index); - _input_pass_row_count.push_back(cur_row_count); - cur_rg_start = cur_rg_index; - cur_pass_byte_size = row_group.total_byte_size; - } - } else { - cur_pass_byte_size += row_group.total_byte_size; - } - cur_row_count += row_group.num_rows; - } - // add the last pass if necessary - if (_input_pass_row_group_offsets.back() != row_groups_info.size()) { - _input_pass_row_group_offsets.push_back(row_groups_info.size()); - _input_pass_row_count.push_back(cur_row_count); - } -} - -void reader::impl::setup_pass() -{ - // this will also cause the previous pass information to be deleted - _pass_itm_data = std::make_unique(); - - // setup row groups to be loaded for this pass - auto const row_group_start = _input_pass_row_group_offsets[_current_input_pass]; - auto const row_group_end = _input_pass_row_group_offsets[_current_input_pass + 1]; - auto const num_row_groups = row_group_end - row_group_start; - _pass_itm_data->row_groups.resize(num_row_groups); - std::copy(_file_itm_data.row_groups.begin() + row_group_start, - _file_itm_data.row_groups.begin() + row_group_end, - _pass_itm_data->row_groups.begin()); - - auto const num_passes = _input_pass_row_group_offsets.size() - 1; - CUDF_EXPECTS(_current_input_pass < num_passes, "Encountered an invalid read pass index"); - - auto const chunks_per_rowgroup = _input_columns.size(); - auto const num_chunks = chunks_per_rowgroup * num_row_groups; - - auto chunk_start = _file_itm_data.chunks.begin() + (row_group_start * chunks_per_rowgroup); - auto chunk_end = _file_itm_data.chunks.begin() + (row_group_end * chunks_per_rowgroup); - - _pass_itm_data->chunks = cudf::detail::hostdevice_vector(num_chunks, _stream); - std::copy(chunk_start, chunk_end, _pass_itm_data->chunks.begin()); - - // adjust skip_rows and num_rows by what's available in the row groups we are processing - if (num_passes == 1) { - _pass_itm_data->skip_rows = _file_itm_data.global_skip_rows; - _pass_itm_data->num_rows = _file_itm_data.global_num_rows; - } else { - auto const global_start_row = _file_itm_data.global_skip_rows; - auto const global_end_row = global_start_row + _file_itm_data.global_num_rows; - auto const start_row = std::max(_input_pass_row_count[_current_input_pass], global_start_row); - auto const end_row = std::min(_input_pass_row_count[_current_input_pass + 1], global_end_row); - - // skip_rows is always global in the sense that it is relative to the first row of - // everything we will be reading, regardless of what pass we are on. - // num_rows is how many rows we are reading this pass. - _pass_itm_data->skip_rows = global_start_row + _input_pass_row_count[_current_input_pass]; - _pass_itm_data->num_rows = end_row - start_row; - } -} - void reader::impl::load_and_decompress_data() { // This function should never be called if `num_rows == 0`. @@ -1034,359 +835,8 @@ void print_pages(cudf::detail::hostdevice_vector& pages, rmm::cuda_str p.str_bytes); } } - -void print_cumulative_page_info(cudf::detail::hostdevice_vector& pages, - rmm::device_uvector const& page_index, - rmm::device_uvector const& c_info, - rmm::cuda_stream_view stream) -{ - pages.device_to_host_sync(stream); - - printf("------------\nCumulative sizes by page\n"); - - std::vector schemas(pages.size()); - std::vector h_page_index(pages.size()); - CUDF_CUDA_TRY(cudaMemcpy( - h_page_index.data(), page_index.data(), sizeof(int) * pages.size(), cudaMemcpyDefault)); - std::vector h_cinfo(pages.size()); - CUDF_CUDA_TRY(cudaMemcpy( - h_cinfo.data(), c_info.data(), sizeof(cumulative_row_info) * pages.size(), cudaMemcpyDefault)); - auto schema_iter = cudf::detail::make_counting_transform_iterator( - 0, [&](size_type i) { return pages[h_page_index[i]].src_col_schema; }); - thrust::copy(thrust::seq, schema_iter, schema_iter + pages.size(), schemas.begin()); - auto last = thrust::unique(thrust::seq, schemas.begin(), schemas.end()); - schemas.resize(last - schemas.begin()); - printf("Num schemas: %lu\n", schemas.size()); - - for (size_t idx = 0; idx < schemas.size(); idx++) { - printf("Schema %d\n", schemas[idx]); - for (size_t pidx = 0; pidx < pages.size(); pidx++) { - auto const& page = pages[h_page_index[pidx]]; - if (page.flags & PAGEINFO_FLAGS_DICTIONARY || page.src_col_schema != schemas[idx]) { - continue; - } - printf("\tP: {%lu, %lu}\n", h_cinfo[pidx].row_count, h_cinfo[pidx].size_bytes); - } - } -} - -void print_cumulative_row_info(host_span sizes, - std::string const& label, - std::optional> splits = std::nullopt) -{ - if (splits.has_value()) { - printf("------------\nSplits\n"); - for (size_t idx = 0; idx < splits->size(); idx++) { - printf("{%lu, %lu}\n", splits.value()[idx].skip_rows, splits.value()[idx].num_rows); - } - } - - printf("------------\nCumulative sizes %s\n", label.c_str()); - for (size_t idx = 0; idx < sizes.size(); idx++) { - printf("{%lu, %lu, %d}", sizes[idx].row_count, sizes[idx].size_bytes, sizes[idx].key); - if (splits.has_value()) { - // if we have a split at this row count and this is the last instance of this row count - auto start = thrust::make_transform_iterator( - splits->begin(), [](chunk_read_info const& i) { return i.skip_rows; }); - auto end = start + splits->size(); - auto split = std::find(start, end, sizes[idx].row_count); - auto const split_index = [&]() -> int { - if (split != end && - ((idx == sizes.size() - 1) || (sizes[idx + 1].row_count > sizes[idx].row_count))) { - return static_cast(std::distance(start, split)); - } - return idx == 0 ? 0 : -1; - }(); - if (split_index >= 0) { - printf(" <-- split {%lu, %lu}", - splits.value()[split_index].skip_rows, - splits.value()[split_index].num_rows); - } - } - printf("\n"); - } -} #endif // PREPROCESS_DEBUG -/** - * @brief Functor which reduces two cumulative_row_info structs of the same key. - */ -struct cumulative_row_sum { - cumulative_row_info operator() - __device__(cumulative_row_info const& a, cumulative_row_info const& b) const - { - return cumulative_row_info{a.row_count + b.row_count, a.size_bytes + b.size_bytes, a.key}; - } -}; - -/** - * @brief Functor which computes the total data size for a given type of cudf column. - * - * In the case of strings, the return size does not include the chars themselves. That - * information is tracked separately (see PageInfo::str_bytes). - */ -struct row_size_functor { - __device__ size_t validity_size(size_t num_rows, bool nullable) - { - return nullable ? (cudf::util::div_rounding_up_safe(num_rows, size_t{32}) * 4) : 0; - } - - template - __device__ size_t operator()(size_t num_rows, bool nullable) - { - auto const element_size = sizeof(device_storage_type_t); - return (element_size * num_rows) + validity_size(num_rows, nullable); - } -}; - -template <> -__device__ size_t row_size_functor::operator()(size_t num_rows, bool nullable) -{ - auto const offset_size = sizeof(size_type); - // NOTE: Adding the + 1 offset here isn't strictly correct. There will only be 1 extra offset - // for the entire column, whereas this is adding an extra offset per page. So we will get a - // small over-estimate of the real size of the order : # of pages * 4 bytes. It seems better - // to overestimate size somewhat than to underestimate it and potentially generate chunks - // that are too large. - return (offset_size * (num_rows + 1)) + validity_size(num_rows, nullable); -} - -template <> -__device__ size_t row_size_functor::operator()(size_t num_rows, bool nullable) -{ - return validity_size(num_rows, nullable); -} - -template <> -__device__ size_t row_size_functor::operator()(size_t num_rows, bool nullable) -{ - // only returns the size of offsets and validity. the size of the actual string chars - // is tracked separately. - auto const offset_size = sizeof(size_type); - // see note about offsets in the list_view template. - return (offset_size * (num_rows + 1)) + validity_size(num_rows, nullable); -} - -/** - * @brief Functor which computes the total output cudf data size for all of - * the data in this page. - * - * Sums across all nesting levels. - */ -struct get_cumulative_row_info { - PageInfo const* const pages; - - __device__ cumulative_row_info operator()(size_type index) - { - auto const& page = pages[index]; - if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { - return cumulative_row_info{0, 0, page.src_col_schema}; - } - - // total nested size, not counting string data - auto iter = - cudf::detail::make_counting_transform_iterator(0, [page, index] __device__(size_type i) { - auto const& pni = page.nesting[i]; - return cudf::type_dispatcher( - data_type{pni.type}, row_size_functor{}, pni.size, pni.nullable); - }); - - size_t const row_count = static_cast(page.nesting[0].size); - return { - row_count, - thrust::reduce(thrust::seq, iter, iter + page.num_output_nesting_levels) + page.str_bytes, - page.src_col_schema}; - } -}; - -/** - * @brief Functor which computes the effective size of all input columns by page. - * - * For a given row, we want to find the cost of all pages for all columns involved - * in loading up to that row. The complication here is that not all pages are the - * same size between columns. Example: - * - * page row counts - * Column A: 0 <----> 100 <----> 200 - * Column B: 0 <---------------> 200 <--------> 400 - | - * if we decide to split at row 100, we don't really know the actual amount of bytes in column B - * at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that - * page. Essentially, a conservative over-estimate of the real size. - */ -struct row_total_size { - cumulative_row_info const* c_info; - size_type const* key_offsets; - size_t num_keys; - - __device__ cumulative_row_info operator()(cumulative_row_info const& i) - { - // sum sizes for each input column at this row - size_t sum = 0; - for (int idx = 0; idx < num_keys; idx++) { - auto const start = key_offsets[idx]; - auto const end = key_offsets[idx + 1]; - auto iter = cudf::detail::make_counting_transform_iterator( - 0, [&] __device__(size_type i) { return c_info[i].row_count; }); - auto const page_index = - thrust::lower_bound(thrust::seq, iter + start, iter + end, i.row_count) - iter; - sum += c_info[page_index].size_bytes; - } - return {i.row_count, sum, i.key}; - } -}; - -/** - * @brief Given a vector of cumulative {row_count, byte_size} pairs and a chunk read - * limit, determine the set of splits. - * - * @param sizes Vector of cumulative {row_count, byte_size} pairs - * @param num_rows Total number of rows to read - * @param chunk_read_limit Limit on total number of bytes to be returned per read, for all columns - */ -std::vector find_splits(std::vector const& sizes, - size_t num_rows, - size_t chunk_read_limit) -{ - // now we have an array of {row_count, real output bytes}. just walk through it and generate - // splits. - // TODO: come up with a clever way to do this entirely in parallel. For now, as long as batch - // sizes are reasonably large, this shouldn't iterate too many times - std::vector splits; - { - size_t cur_pos = 0; - size_t cur_cumulative_size = 0; - size_t cur_row_count = 0; - auto start = thrust::make_transform_iterator(sizes.begin(), [&](cumulative_row_info const& i) { - return i.size_bytes - cur_cumulative_size; - }); - auto end = start + sizes.size(); - while (cur_row_count < num_rows) { - int64_t split_pos = - thrust::lower_bound(thrust::seq, start + cur_pos, end, chunk_read_limit) - start; - - // if we're past the end, or if the returned bucket is > than the chunk_read_limit, move back - // one. - if (static_cast(split_pos) >= sizes.size() || - (sizes[split_pos].size_bytes - cur_cumulative_size > chunk_read_limit)) { - split_pos--; - } - - // best-try. if we can't find something that'll fit, we have to go bigger. we're doing this in - // a loop because all of the cumulative sizes for all the pages are sorted into one big list. - // so if we had two columns, both of which had an entry {1000, 10000}, that entry would be in - // the list twice. so we have to iterate until we skip past all of them. The idea is that we - // either do this, or we have to call unique() on the input first. - while (split_pos < (static_cast(sizes.size()) - 1) && - (split_pos < 0 || sizes[split_pos].row_count == cur_row_count)) { - split_pos++; - } - - auto const start_row = cur_row_count; - cur_row_count = sizes[split_pos].row_count; - splits.push_back(chunk_read_info{start_row, cur_row_count - start_row}); - cur_pos = split_pos; - cur_cumulative_size = sizes[split_pos].size_bytes; - } - } - // print_cumulative_row_info(sizes, "adjusted", splits); - - return splits; -} - -/** - * @brief Given a set of pages that have had their sizes computed by nesting level and - * a limit on total read size, generate a set of {skip_rows, num_rows} pairs representing - * a set of reads that will generate output columns of total size <= `chunk_read_limit` bytes. - * - * @param pages All pages in the file - * @param id Additional intermediate information required to process the pages - * @param num_rows Total number of rows to read - * @param chunk_read_limit Limit on total number of bytes to be returned per read, for all columns - * @param stream CUDA stream to use - */ -std::vector compute_splits(cudf::detail::hostdevice_vector& pages, - pass_intermediate_data const& id, - size_t num_rows, - size_t chunk_read_limit, - rmm::cuda_stream_view stream) -{ - auto const& page_keys = id.page_keys; - auto const& page_index = id.page_index; - - // generate cumulative row counts and sizes - rmm::device_uvector c_info(page_keys.size(), stream); - // convert PageInfo to cumulative_row_info - auto page_input = thrust::make_transform_iterator(page_index.begin(), - get_cumulative_row_info{pages.device_ptr()}); - thrust::inclusive_scan_by_key(rmm::exec_policy(stream), - page_keys.begin(), - page_keys.end(), - page_input, - c_info.begin(), - thrust::equal_to{}, - cumulative_row_sum{}); - // print_cumulative_page_info(pages, page_index, c_info, stream); - - // sort by row count - rmm::device_uvector c_info_sorted{c_info, stream}; - thrust::sort(rmm::exec_policy(stream), - c_info_sorted.begin(), - c_info_sorted.end(), - [] __device__(cumulative_row_info const& a, cumulative_row_info const& b) { - return a.row_count < b.row_count; - }); - - // std::vector h_c_info_sorted(c_info_sorted.size()); - // CUDF_CUDA_TRY(cudaMemcpy(h_c_info_sorted.data(), - // c_info_sorted.data(), - // sizeof(cumulative_row_info) * c_info_sorted.size(), - // cudaMemcpyDefault)); - // print_cumulative_row_info(h_c_info_sorted, "raw"); - - // generate key offsets (offsets to the start of each partition of keys). worst case is 1 page per - // key - rmm::device_uvector key_offsets(page_keys.size() + 1, stream); - auto const key_offsets_end = thrust::reduce_by_key(rmm::exec_policy(stream), - page_keys.begin(), - page_keys.end(), - thrust::make_constant_iterator(1), - thrust::make_discard_iterator(), - key_offsets.begin()) - .second; - size_t const num_unique_keys = key_offsets_end - key_offsets.begin(); - thrust::exclusive_scan( - rmm::exec_policy(stream), key_offsets.begin(), key_offsets.end(), key_offsets.begin()); - - // adjust the cumulative info such that for each row count, the size includes any pages that span - // that row count. this is so that if we have this case: - // page row counts - // Column A: 0 <----> 100 <----> 200 - // Column B: 0 <---------------> 200 <--------> 400 - // | - // if we decide to split at row 100, we don't really know the actual amount of bytes in column B - // at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that - // page. - // - rmm::device_uvector aggregated_info(c_info.size(), stream); - thrust::transform(rmm::exec_policy(stream), - c_info_sorted.begin(), - c_info_sorted.end(), - aggregated_info.begin(), - row_total_size{c_info.data(), key_offsets.data(), num_unique_keys}); - - // bring back to the cpu - std::vector h_aggregated_info(aggregated_info.size()); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_aggregated_info.data(), - aggregated_info.data(), - sizeof(cumulative_row_info) * c_info.size(), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); - - return find_splits(h_aggregated_info, num_rows, chunk_read_limit); -} - struct get_page_chunk_idx { __device__ size_type operator()(PageInfo const& page) { return page.chunk_idx; } }; @@ -1822,12 +1272,8 @@ void reader::impl::preprocess_pages(bool uses_custom_row_bounds, size_t chunk_re _pass_itm_data->page_keys = std::move(page_keys); _pass_itm_data->page_index = std::move(page_index); - // compute splits if necessary. otherwise return a single split representing - // the whole file. - _pass_itm_data->output_chunk_read_info = - _output_chunk_read_limit > 0 - ? compute_splits(pages, *_pass_itm_data, num_rows, chunk_read_limit, _stream) - : std::vector{{skip_rows, num_rows}}; + // compute splits for the pass + compute_splits_for_pass(); } void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses_custom_row_bounds) diff --git a/cpp/src/io/utilities/column_buffer.cpp b/cpp/src/io/utilities/column_buffer.cpp index f3a43cbc63c..dd049d401cf 100644 --- a/cpp/src/io/utilities/column_buffer.cpp +++ b/cpp/src/io/utilities/column_buffer.cpp @@ -51,19 +51,21 @@ std::unique_ptr gather_column_buffer::make_string_column_impl(rmm::cuda_ return make_strings_column(*_strings, stream, _mr); } -void inline_column_buffer::allocate_strings_data(rmm::cuda_stream_view stream) +void cudf::io::detail::inline_column_buffer::allocate_strings_data(rmm::cuda_stream_view stream) { CUDF_EXPECTS(type.id() == type_id::STRING, "allocate_strings_data called for non-string column"); // size + 1 for final offset. _string_data will be initialized later. _data = create_data(data_type{type_id::INT32}, size + 1, stream, _mr); } -void inline_column_buffer::create_string_data(size_t num_bytes, rmm::cuda_stream_view stream) +void cudf::io::detail::inline_column_buffer::create_string_data(size_t num_bytes, + rmm::cuda_stream_view stream) { _string_data = rmm::device_buffer(num_bytes, stream, _mr); } -std::unique_ptr inline_column_buffer::make_string_column_impl(rmm::cuda_stream_view stream) +std::unique_ptr cudf::io::detail::inline_column_buffer::make_string_column_impl( + rmm::cuda_stream_view stream) { // no need for copies, just transfer ownership of the data_buffers to the columns auto const state = mask_state::UNALLOCATED; @@ -324,7 +326,7 @@ std::unique_ptr empty_like(column_buffer_base& buffer, } using pointer_type = gather_column_buffer; -using string_type = inline_column_buffer; +using string_type = cudf::io::detail::inline_column_buffer; using pointer_column_buffer = column_buffer_base; using string_column_buffer = column_buffer_base;