From 47d6c1568c8504e714068a909ee1dc639aef0f5d Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Mon, 7 Aug 2023 14:00:02 +0530 Subject: [PATCH] Add memory usage accounting for buffered pages in writer --- .../trino/parquet/writer/ParquetWriter.java | 1 + .../parquet/writer/PrimitiveColumnWriter.java | 1 + .../io/trino/parquet/ParquetTestUtils.java | 24 ++++++++------ .../parquet/writer/TestParquetWriter.java | 32 +++++++++++++++++++ 4 files changed, 49 insertions(+), 9 deletions(-) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 7dd27a813c5b..38d5170e65c6 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -203,6 +203,7 @@ public void close() try (outputStream) { columnWriters.forEach(ColumnWriter::close); flush(); + columnWriters = ImmutableList.of(); writeFooter(); } bufferedBytes = 0; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java index 3d695d8fde1f..9faf5ce07532 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java @@ -297,6 +297,7 @@ public long getBufferedBytes() public long getRetainedBytes() { return INSTANCE_SIZE + + compressedOutputStream.getRetainedSize() + primitiveValueWriter.getAllocatedSize() + definitionLevelWriter.getAllocatedSize() + repetitionLevelWriter.getAllocatedSize(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index be833eddcdb6..06c9c79e5d30 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -37,6 +37,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.OutputStream; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -66,11 +67,23 @@ private ParquetTestUtils() {} public static Slice writeParquetFile(ParquetWriterOptions writerOptions, List types, List columnNames, List inputPages) throws IOException + { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ParquetWriter writer = createParquetWriter(outputStream, writerOptions, types, columnNames); + + for (io.trino.spi.Page inputPage : inputPages) { + checkArgument(types.size() == inputPage.getChannelCount()); + writer.write(inputPage); + } + writer.close(); + return Slices.wrappedBuffer(outputStream.toByteArray()); + } + + public static ParquetWriter createParquetWriter(OutputStream outputStream, ParquetWriterOptions writerOptions, List types, List columnNames) { checkArgument(types.size() == columnNames.size()); ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter(types, columnNames, false, false); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - ParquetWriter writer = new ParquetWriter( + return new ParquetWriter( outputStream, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), @@ -80,13 +93,6 @@ public static Slice writeParquetFile(ParquetWriterOptions writerOptions, List columnNames = ImmutableList.of("columnA", "columnB"); + List types = ImmutableList.of(INTEGER, INTEGER); + + ParquetWriter writer = createParquetWriter( + new ByteArrayOutputStream(), + ParquetWriterOptions.builder() + .setMaxPageSize(DataSize.ofBytes(1024)) + .build(), + types, + columnNames); + List inputPages = generateInputPages(types, 1000, 100); + + long previousRetainedBytes = 0; + for (io.trino.spi.Page inputPage : inputPages) { + checkArgument(types.size() == inputPage.getChannelCount()); + writer.write(inputPage); + long currentRetainedBytes = writer.getRetainedBytes(); + assertThat(currentRetainedBytes).isGreaterThanOrEqualTo(previousRetainedBytes); + previousRetainedBytes = currentRetainedBytes; + } + assertThat(previousRetainedBytes).isGreaterThanOrEqualTo(2 * Integer.BYTES * 1000 * 100); + writer.close(); + assertThat(previousRetainedBytes - writer.getRetainedBytes()).isGreaterThanOrEqualTo(2 * Integer.BYTES * 1000 * 100); + } }