diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs
index 74fff993552a..b7c2db255b69 100644
--- a/parquet/src/arrow/arrow_reader/mod.rs
+++ b/parquet/src/arrow/arrow_reader/mod.rs
@@ -115,7 +115,11 @@ impl ArrowReaderBuilder {
     }
 
     /// Set the size of [`RecordBatch`] to produce. Defaults to 1024
+    /// If `batch_size` exceeds the file's row count, the row count is used instead.
     pub fn with_batch_size(self, batch_size: usize) -> Self {
+        // Clamp to the file's row count to avoid allocating an oversized buffer
+        let batch_size =
+            batch_size.min(self.metadata.file_metadata().num_rows() as usize);
         Self { batch_size, ..self }
     }
 
diff --git a/parquet/src/arrow/async_reader.rs b/parquet/src/arrow/async_reader.rs
index b0d9143d64d1..201f2afcf0e8 100644
--- a/parquet/src/arrow/async_reader.rs
+++ b/parquet/src/arrow/async_reader.rs
@@ -329,6 +329,10 @@ impl ArrowReaderBuilder> {
             None => (0..self.metadata.row_groups().len()).collect(),
         };
 
+        // Clamp to the file's row count to avoid allocating an oversized buffer
+        let batch_size = self
+            .batch_size
+            .min(self.metadata.file_metadata().num_rows() as usize);
         let reader = ReaderFactory {
             input: self.input.0,
             filter: self.filter,
@@ -338,7 +342,7 @@ impl ArrowReaderBuilder> {
 
         Ok(ParquetRecordBatchStream {
             metadata: self.metadata,
-            batch_size: self.batch_size,
+            batch_size,
             row_groups,
             projection: self.projection,
             selection: self.selection,
@@ -1133,4 +1137,34 @@ mod tests {
 
         assert_eq!(&requests[..], &expected_page_requests)
     }
+
+    #[tokio::test]
+    async fn test_batch_size_overallocate() {
+        let testdata = arrow::util::test_util::parquet_test_data();
+        // `alltypes_plain.parquet` only has 8 rows
+        let path = format!("{}/alltypes_plain.parquet", testdata);
+        let data = Bytes::from(std::fs::read(path).unwrap());
+
+        let metadata = parse_metadata(&data).unwrap();
+        let file_rows = metadata.file_metadata().num_rows() as usize;
+        let metadata = Arc::new(metadata);
+
+        let async_reader = TestReader {
+            data: data.clone(),
+            metadata: metadata.clone(),
+            requests: Default::default(),
+        };
+
+        let builder = ParquetRecordBatchStreamBuilder::new(async_reader)
+            .await
+            .unwrap();
+
+        let stream = builder
+            .with_projection(ProjectionMask::all())
+            .with_batch_size(1024)
+            .build()
+            .unwrap();
+        assert_ne!(1024, file_rows);
+        assert_eq!(stream.batch_size, file_rows);
+    }
 }