diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 27d783cd89b5..694c94928537 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -48,6 +48,7 @@ use datafusion_common::{ DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; @@ -749,9 +750,13 @@ impl DataSink for ParquetSink { parquet_props.writer_options().clone(), ) .await?; + let mut reservation = + MemoryConsumer::new(format!("ParquetSink[{}]", path)) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; + reservation.try_resize(writer.memory_size())?; } let file_metadata = writer .close() @@ -771,6 +776,7 @@ impl DataSink for ParquetSink { let schema = self.get_writer_schema(); let props = parquet_props.clone(); let parallel_options_clone = parallel_options.clone(); + let pool = Arc::clone(context.memory_pool()); file_write_tasks.spawn(async move { let file_metadata = output_single_parquet_file_parallelized( writer, @@ -778,6 +784,7 @@ impl DataSink for ParquetSink { schema, props.writer_options(), parallel_options_clone, + pool, ) .await?; Ok((path, file_metadata)) @@ -818,14 +825,16 @@ impl DataSink for ParquetSink { async fn column_serializer_task( mut rx: Receiver, mut writer: ArrowColumnWriter, -) -> Result { + mut reservation: MemoryReservation, +) -> Result<(ArrowColumnWriter, MemoryReservation)> { while let Some(col) = rx.recv().await { writer.write(&col)?; + reservation.try_resize(writer.memory_size())?; } - Ok(writer) + Ok((writer, reservation)) } -type ColumnWriterTask = SpawnedTask>; +type ColumnWriterTask = SpawnedTask>; type ColSender = Sender; /// Spawns a parallel serialization task for each column @@ -835,6 +844,7 @@ fn spawn_column_parallel_row_group_writer( schema: Arc, parquet_props: Arc, max_buffer_size: usize, + pool: &Arc, ) -> Result<(Vec, Vec)> { let schema_desc = arrow_to_parquet_schema(&schema)?; let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; @@ -848,7 +858,13 @@ fn spawn_column_parallel_row_group_writer( mpsc::channel::(max_buffer_size); col_array_channels.push(send_array); - let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer)); + let reservation = + MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); + let task = SpawnedTask::spawn(column_serializer_task( + recieve_array, + writer, + reservation, + )); col_writer_tasks.push(task); } @@ -864,7 +880,7 @@ struct ParallelParquetWriterOptions { /// This is the return type of calling [ArrowColumnWriter].close() on each column /// i.e. the Vec of encoded columns which can be appended to a row group -type RBStreamSerializeResult = Result<(Vec, usize)>; +type RBStreamSerializeResult = Result<(Vec, MemoryReservation, usize)>; /// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective /// parallel column serializers. @@ -895,16 +911,22 @@ async fn send_arrays_to_col_writers( fn spawn_rg_join_and_finalize_task( column_writer_tasks: Vec, rg_rows: usize, + pool: &Arc, ) -> SpawnedTask { + let mut rg_reservation = + MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); + SpawnedTask::spawn(async move { let num_cols = column_writer_tasks.len(); let mut finalized_rg = Vec::with_capacity(num_cols); for task in column_writer_tasks.into_iter() { - let writer = task.join_unwind().await?; + let (writer, _col_reservation) = task.join_unwind().await?; + let encoded_size = writer.get_estimated_total_bytes(); + rg_reservation.grow(encoded_size); finalized_rg.push(writer.close()?); } - Ok((finalized_rg, rg_rows)) + Ok((finalized_rg, rg_reservation, rg_rows)) }) } @@ -922,6 +944,7 @@ fn spawn_parquet_parallel_serialization_task( schema: Arc, writer_props: Arc, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> SpawnedTask> { SpawnedTask::spawn(async move { let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; @@ -931,6 +954,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; let mut current_rg_rows = 0; @@ -957,6 +981,7 @@ fn spawn_parquet_parallel_serialization_task( let finalize_rg_task = spawn_rg_join_and_finalize_task( column_writer_handles, max_row_group_rows, + &pool, ); serialize_tx.send(finalize_rg_task).await.map_err(|_| { @@ -973,6 +998,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; } } @@ -981,8 +1007,11 @@ fn spawn_parquet_parallel_serialization_task( drop(col_array_channels); // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows if current_rg_rows > 0 { - let finalize_rg_task = - spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + current_rg_rows, + &pool, + ); serialize_tx.send(finalize_rg_task).await.map_err(|_| { DataFusionError::Internal( @@ -1002,9 +1031,13 @@ async fn concatenate_parallel_row_groups( schema: Arc, writer_props: Arc, mut object_store_writer: Box, + pool: Arc, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut file_reservation = + MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); + let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( merged_buff.clone(), @@ -1015,15 +1048,20 @@ async fn concatenate_parallel_row_groups( while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; let mut rg_out = parquet_writer.next_row_group()?; - let (serialized_columns, _cnt) = result?; + let (serialized_columns, mut rg_reservation, _cnt) = result?; for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; + rg_reservation.free(); + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + file_reservation.try_resize(buff_to_flush.len())?; + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; buff_to_flush.clear(); + file_reservation.try_resize(buff_to_flush.len())?; // will set to zero } } rg_out.close()?; @@ -1034,6 +1072,7 @@ async fn concatenate_parallel_row_groups( object_store_writer.write_all(final_buff.as_slice()).await?; object_store_writer.shutdown().await?; + file_reservation.free(); Ok(file_metadata) } @@ -1048,6 +1087,7 @@ async fn output_single_parquet_file_parallelized( output_schema: Arc, parquet_props: &WriterProperties, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> Result { let max_rowgroups = parallel_options.max_parallel_row_groups; // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel @@ -1061,12 +1101,14 @@ async fn output_single_parquet_file_parallelized( output_schema.clone(), arc_props.clone(), parallel_options, + Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( serialize_rx, output_schema.clone(), arc_props.clone(), object_store_writer, + pool, ) .await?; @@ -1158,8 +1200,10 @@ mod tests { use super::super::test_util::scan_format; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::physical_plan::collect; + use crate::test_util::bounded_stream; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; use super::*; @@ -2177,4 +2221,105 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn parquet_sink_write_memory_reservation() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); + + let file_sink_config = FileSinkConfig { + object_store_url: object_store_url.clone(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + overwrite: true, + keep_partition_by_columns: false, + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); + + let mut write_task = parquet_sink.write_all( + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; + } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); + + Ok(()) + } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) + } } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 059fa8fc6da7..ba0509f3f51a 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -366,3 +366,39 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +struct BoundedStream { + limit: usize, + count: usize, + batch: RecordBatch, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + return Poll::Ready(None); + } + self.count += 1; + Poll::Ready(Some(Ok(self.batch.clone()))) + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +/// Creates an bounded stream for testing purposes. +pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + count: 0, + limit, + batch, + }) +} diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f61ee5d9ab98..f7402357d1c7 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -31,6 +31,7 @@ use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use futures::StreamExt; use std::any::Any; use std::sync::{Arc, OnceLock}; +use tokio::fs::File; use datafusion::datasource::streaming::StreamingTable; use datafusion::datasource::{MemTable, TableProvider}; @@ -323,6 +324,30 @@ async fn oom_recursive_cte() { .await } +#[tokio::test] +async fn oom_parquet_sink() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.into_path().join("test.parquet"); + let _ = File::create(path.clone()).await.unwrap(); + + TestCase::new() + .with_query(format!( + " + COPY (select * from t) + TO '{}' + STORED AS PARQUET OPTIONS (compression 'uncompressed'); + ", + path.to_string_lossy() + )) + .with_expected_errors(vec![ + // TODO: update error handling in ParquetSink + "Unable to send array to writer!", + ]) + .with_memory_limit(200_000) + .run() + .await +} + /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)]