diff --git a/cpp/include/cudf/io/detail/orc.hpp b/cpp/include/cudf/io/detail/orc.hpp index 3c1486b60c2..c63c952e148 100644 --- a/cpp/include/cudf/io/detail/orc.hpp +++ b/cpp/include/cudf/io/detail/orc.hpp @@ -124,14 +124,6 @@ class writer { * @brief Finishes the chunked/streamed write process. */ void close(); - - /** - * @brief Skip work done in `close()`; should be called if `write()` failed. - * - * Calling skip_close() prevents the writer from writing the (invalid) file footer and the - * postscript. - */ - void skip_close(); }; } // namespace orc::detail } // namespace cudf::io diff --git a/cpp/src/io/functions.cpp b/cpp/src/io/functions.cpp index b8353d312fe..46c6c67c8df 100644 --- a/cpp/src/io/functions.cpp +++ b/cpp/src/io/functions.cpp @@ -436,16 +436,7 @@ void write_orc(orc_writer_options const& options, rmm::cuda_stream_view stream) auto writer = std::make_unique( std::move(sinks[0]), options, io_detail::single_write_mode::YES, stream); - try { - writer->write(options.get_table()); - } catch (...) { - // If an exception is thrown, the output is incomplete/corrupted. - // Make sure the writer will not close with such corrupted data. - // In addition, the writer may throw an exception while trying to close, which would terminate - // the process. - writer->skip_close(); - throw; - } + writer->write(options.get_table()); } /** diff --git a/cpp/src/io/orc/writer_impl.cu b/cpp/src/io/orc/writer_impl.cu index ade0e75de35..750a593920c 100644 --- a/cpp/src/io/orc/writer_impl.cu +++ b/cpp/src/io/orc/writer_impl.cu @@ -2438,7 +2438,6 @@ writer::impl::impl(std::unique_ptr sink, if (options.get_metadata()) { _table_meta = std::make_unique(*options.get_metadata()); } - init_state(); } writer::impl::impl(std::unique_ptr sink, @@ -2460,20 +2459,13 @@ writer::impl::impl(std::unique_ptr sink, if (options.get_metadata()) { _table_meta = std::make_unique(*options.get_metadata()); } - init_state(); } writer::impl::~impl() { close(); } -void writer::impl::init_state() -{ - // Write file header - _out_sink->host_write(MAGIC, std::strlen(MAGIC)); -} - void writer::impl::write(table_view const& input) { - CUDF_EXPECTS(not _closed, "Data has already been flushed to out and closed"); + CUDF_EXPECTS(_state != writer_state::CLOSED, "Data has already been flushed to out and closed"); if (not _table_meta) { _table_meta = make_table_meta(input); } @@ -2516,6 +2508,11 @@ void writer::impl::write(table_view const& input) } }(); + if (_state == writer_state::NO_DATA_WRITTEN) { + // Write the ORC file header if this is the first write + _out_sink->host_write(MAGIC, std::strlen(MAGIC)); + } + // Compression/encoding were all successful. Now write the intermediate results. write_orc_data_to_sink(enc_data, segmentation, @@ -2533,6 +2530,8 @@ void writer::impl::write(table_view const& input) // Update file-level and compression statistics update_statistics(orc_table.num_rows(), std::move(intermediate_stats), compression_stats); + + _state = writer_state::DATA_WRITTEN; } void writer::impl::update_statistics( @@ -2683,8 +2682,11 @@ void writer::impl::add_table_to_footer_data(orc_table_view const& orc_table, void writer::impl::close() { - if (_closed) { return; } - _closed = true; + if (_state != writer_state::DATA_WRITTEN) { + // writer is either closed or no data has been written + _state = writer_state::CLOSED; + return; + } PostScript ps; if (_stats_freq != statistics_freq::STATISTICS_NONE) { @@ -2769,6 +2771,8 @@ void writer::impl::close() pbw.put_byte(ps_length); _out_sink->host_write(pbw.data(), pbw.size()); _out_sink->flush(); + + _state = writer_state::CLOSED; } // Forward to implementation @@ -2795,9 +2799,6 @@ writer::~writer() = default; // Forward to implementation void writer::write(table_view const& table) { _impl->write(table); } -// Forward to implementation -void writer::skip_close() { _impl->skip_close(); } - // Forward to implementation void writer::close() { _impl->close(); } diff --git a/cpp/src/io/orc/writer_impl.hpp b/cpp/src/io/orc/writer_impl.hpp index 417d29efb58..bd082befe0c 100644 --- a/cpp/src/io/orc/writer_impl.hpp +++ b/cpp/src/io/orc/writer_impl.hpp @@ -227,6 +227,14 @@ struct encoded_footer_statistics { std::vector file_level; }; +enum class writer_state { + NO_DATA_WRITTEN, // No table data has been written to the sink; if the writer is closed or + // destroyed in this state, it should not write the footer. + DATA_WRITTEN, // At least one table has been written to the sink; when the writer is closed, + // it should write the footer. + CLOSED // Writer has been closed; no further writes are allowed. +}; + /** * @brief Implementation for ORC writer */ @@ -266,11 +274,6 @@ class writer::impl { */ ~impl(); - /** - * @brief Begins the chunked/streamed write process. - */ - void init_state(); - /** * @brief Writes a single subtable as part of a larger ORC file/table write. * @@ -283,11 +286,6 @@ class writer::impl { */ void close(); - /** - * @brief Skip writing the footer when closing/deleting the writer. - */ - void skip_close() { _closed = true; } - private: /** * @brief Write the intermediate ORC data into the data sink. @@ -363,7 +361,7 @@ class writer::impl { Footer _footer; Metadata _orc_meta; persisted_statistics _persisted_stripe_statistics; // Statistics data saved between calls. - bool _closed = false; // To track if the output has been written to sink. + writer_state _state = writer_state::NO_DATA_WRITTEN; }; } // namespace cudf::io::orc::detail diff --git a/cpp/tests/io/orc_test.cpp b/cpp/tests/io/orc_test.cpp index 24e2e2cfea0..e108e68e1f9 100644 --- a/cpp/tests/io/orc_test.cpp +++ b/cpp/tests/io/orc_test.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -2100,8 +2101,7 @@ TEST_F(OrcWriterTest, BounceBufferBug) auto sequence = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 100; }); constexpr auto num_rows = 150000; - column_wrapper col(sequence, - sequence + num_rows); + column_wrapper col(sequence, sequence + num_rows); table_view expected({col}); auto filepath = temp_env->get_temp_filepath("BounceBufferBug.orc"); @@ -2120,8 +2120,7 @@ TEST_F(OrcReaderTest, SizeTypeRowsOverflow) static_assert(total_rows > std::numeric_limits::max()); auto sequence = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 127; }); - column_wrapper col(sequence, - sequence + num_rows); + column_wrapper col(sequence, sequence + num_rows); table_view chunk_table({col}); std::vector out_buffer; @@ -2169,4 +2168,55 @@ TEST_F(OrcReaderTest, SizeTypeRowsOverflow) CUDF_TEST_EXPECT_TABLES_EQUAL(expected, got_with_stripe_selection->view()); } +TEST_F(OrcChunkedWriterTest, NoWriteCloseNotThrow) +{ + std::vector out_buffer; + + cudf::io::chunked_orc_writer_options write_opts = + cudf::io::chunked_orc_writer_options::builder(cudf::io::sink_info{&out_buffer}); + auto writer = cudf::io::orc_chunked_writer(write_opts); + + EXPECT_NO_THROW(writer.close()); +} + +TEST_F(OrcChunkedWriterTest, FailedWriteCloseNotThrow) +{ + // A sink that throws on write() + class throw_sink : public cudf::io::data_sink { + public: + void host_write(void const* data, size_t size) override { throw std::runtime_error("write"); } + void flush() override {} + size_t bytes_written() override { return 0; } + }; + + auto sequence = thrust::make_counting_iterator(0); + column_wrapper col(sequence, sequence + 10); + table_view table({col}); + + throw_sink sink; + cudf::io::chunked_orc_writer_options write_opts = + cudf::io::chunked_orc_writer_options::builder(cudf::io::sink_info{&sink}); + auto writer = cudf::io::orc_chunked_writer(write_opts); + + try { + writer.write(table); + } catch (...) { + // ignore the exception; we're testing that close() doesn't throw when the only write() fails + } + + EXPECT_NO_THROW(writer.close()); +} + +TEST_F(OrcChunkedWriterTest, NoDataInSinkWhenNoWrite) +{ + std::vector out_buffer; + + cudf::io::chunked_orc_writer_options write_opts = + cudf::io::chunked_orc_writer_options::builder(cudf::io::sink_info{&out_buffer}); + auto writer = cudf::io::orc_chunked_writer(write_opts); + + EXPECT_NO_THROW(writer.close()); + EXPECT_EQ(out_buffer.size(), 0); +} + CUDF_TEST_PROGRAM_MAIN()