diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 7289dd2462300..217ef5aad24ee 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1191,8 +1191,10 @@ mod tests { use super::super::test_util::scan_format; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::physical_plan::collect; + use crate::test_util::bounded_stream; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; use super::*; @@ -2210,4 +2212,105 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn parquet_sink_write_memory_reservation() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); + + let file_sink_config = FileSinkConfig { + object_store_url: object_store_url.clone(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + overwrite: true, + keep_partition_by_columns: false, + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); + + let mut write_task = parquet_sink.write_all( + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; + } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); + + Ok(()) + } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) + } }