diff --git a/cpp/src/io/parquet/writer_impl.cu b/cpp/src/io/parquet/writer_impl.cu index 13ec2d652a6..dc61070441d 100644 --- a/cpp/src/io/parquet/writer_impl.cu +++ b/cpp/src/io/parquet/writer_impl.cu @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,7 @@ #include #include #include +#include #include #include @@ -85,6 +87,29 @@ parquet::Compression to_parquet_compression(compression_type compression) } } +size_type column_size(column_view const& column, rmm::cuda_stream_view stream) +{ + if (column.num_children() == 0) { return size_of(column.type()) * column.size(); } + + if (column.type().id() == type_id::STRING) { + auto const scol = strings_column_view(column); + return cudf::detail::get_value(scol.offsets(), column.size(), stream) - + cudf::detail::get_value(scol.offsets(), 0, stream); + } else if (column.type().id() == type_id::STRUCT) { + auto const scol = structs_column_view(column); + size_type ret = 0; + for (int i = 0; i < scol.num_children(); i++) { + ret += column_size(scol.get_sliced_child(i), stream); + } + return ret; + } else if (column.type().id() == type_id::LIST) { + auto const lcol = lists_column_view(column); + return column_size(lcol.get_sliced_child(stream), stream); + } + + return 0; +} + } // namespace struct aggregate_writer_metadata { @@ -1412,10 +1437,22 @@ void writer::impl::write(table_view const& table, std::vector co // iteratively reduce this value if the largest fragment exceeds the max page size limit (we // ideally want the page size to be below 1MB so as to have enough pages to get good // compression/decompression performance). - // If using the default fragment size, scale it up or down depending on the requested page size. + // If using the default fragment size, adapt fragment size so that page size guarantees are met. if (max_page_fragment_size_ == cudf::io::default_max_page_fragment_size) { max_page_fragment_size_ = (cudf::io::default_max_page_fragment_size * max_page_size_bytes) / cudf::io::default_max_page_size_bytes; + + if (table.num_rows() > 0) { + std::for_each( + table.begin(), table.end(), [this, num_rows = table.num_rows()](auto const& column) { + auto const avg_len = column_size(column, stream) / num_rows; + + if (avg_len > 0) { + size_type frag_size = max_page_size_bytes / avg_len; + max_page_fragment_size_ = std::min(frag_size, max_page_fragment_size_); + } + }); + } } std::vector num_frag_in_part;