Add memory usage accounting for buffered pages in writer
raunaqmorarka committed Aug 9, 2023
1 parent 2389844 commit 47d6c15
Showing 4 changed files with 49 additions and 9 deletions.
File 1 of 4

@@ -203,6 +203,7 @@ public void close()
         try (outputStream) {
             columnWriters.forEach(ColumnWriter::close);
             flush();
+            columnWriters = ImmutableList.of();
             writeFooter();
         }
         bufferedBytes = 0;
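Clearing columnWriters on close is what lets retained memory actually drop once the file is written: the column writers hold the buffered pages, so releasing the references returns getRetainedBytes() to roughly its baseline. A minimal caller-side sketch of how this accounting can be consumed, assuming Trino's AggregatedMemoryContext/LocalMemoryContext API and a ParquetWriter in io.trino.parquet.writer; the class and method names here are illustrative, not part of this commit:

    import io.trino.memory.context.AggregatedMemoryContext;
    import io.trino.memory.context.LocalMemoryContext;
    import io.trino.parquet.writer.ParquetWriter;
    import io.trino.spi.Page;

    import java.io.IOException;
    import java.util.List;

    import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;

    public final class WriterMemoryTracking
    {
        private WriterMemoryTracking() {}

        // Hypothetical caller: refresh the reported memory after every page, and once
        // more after close(), when the cleared column writers release their buffers.
        public static void writeWithAccounting(ParquetWriter writer, List<Page> pages)
                throws IOException
        {
            AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext();
            LocalMemoryContext writerContext = memoryContext.newLocalMemoryContext("ParquetWriter");
            for (Page page : pages) {
                writer.write(page);
                // Now includes pages buffered in the compressed output streams
                writerContext.setBytes(writer.getRetainedBytes());
            }
            writer.close();
            // columnWriters was just cleared, so the retained size drops sharply
            writerContext.setBytes(writer.getRetainedBytes());
        }
    }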
File 2 of 4

@@ -297,6 +297,7 @@ public long getBufferedBytes()
     public long getRetainedBytes()
     {
         return INSTANCE_SIZE +
+                compressedOutputStream.getRetainedSize() +
                 primitiveValueWriter.getAllocatedSize() +
                 definitionLevelWriter.getAllocatedSize() +
                 repetitionLevelWriter.getAllocatedSize();
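The added compressedOutputStream.getRetainedSize() term is the substance of the commit: compressed pages accumulate in that buffer until they are flushed to the file, so omitting it under-reported memory by up to the full size of the buffered pages. An illustrative stand-alone sketch of the effect, using airlift's DynamicSliceOutput purely as a stand-in for the writer's actual buffer type:

    import io.airlift.slice.DynamicSliceOutput;

    public class RetainedSizeDemo
    {
        public static void main(String[] args)
        {
            // Stand-in for a column writer's compressed page buffer
            DynamicSliceOutput compressedOutputStream = new DynamicSliceOutput(64);
            compressedOutputStream.writeBytes(new byte[8192]); // one "compressed page"
            // The backing buffer is live memory whether or not it has been flushed
            // to the file yet, so it belongs in the writer's retained size
            System.out.println(compressedOutputStream.getRetainedSize()); // >= 8192
        }
    }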
File 3 of 4

@@ -37,6 +37,7 @@
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.OutputStream;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
@@ -66,11 +67,23 @@ private ParquetTestUtils() {}
 
     public static Slice writeParquetFile(ParquetWriterOptions writerOptions, List<Type> types, List<String> columnNames, List<io.trino.spi.Page> inputPages)
             throws IOException
     {
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ParquetWriter writer = createParquetWriter(outputStream, writerOptions, types, columnNames);
+
+        for (io.trino.spi.Page inputPage : inputPages) {
+            checkArgument(types.size() == inputPage.getChannelCount());
+            writer.write(inputPage);
+        }
+        writer.close();
+        return Slices.wrappedBuffer(outputStream.toByteArray());
+    }
+
+    public static ParquetWriter createParquetWriter(OutputStream outputStream, ParquetWriterOptions writerOptions, List<Type> types, List<String> columnNames)
+    {
         checkArgument(types.size() == columnNames.size());
         ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter(types, columnNames, false, false);
-        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
-        ParquetWriter writer = new ParquetWriter(
+        return new ParquetWriter(
                 outputStream,
                 schemaConverter.getMessageType(),
                 schemaConverter.getPrimitiveTypes(),
@@ -80,13 +93,6 @@ public static Slice writeParquetFile(ParquetWriterOptions writerOptions, List<Type> types, List<String> columnNames, List<io.trino.spi.Page> inputPages)
                 false,
                 Optional.of(DateTimeZone.getDefault()),
                 Optional.empty());
-
-        for (io.trino.spi.Page inputPage : inputPages) {
-            checkArgument(types.size() == inputPage.getChannelCount());
-            writer.write(inputPage);
-        }
-        writer.close();
-        return Slices.wrappedBuffer(outputStream.toByteArray());
     }
 
     public static ParquetReader createParquetReader(
File 4 of 4

@@ -34,13 +34,16 @@
 import org.apache.parquet.schema.PrimitiveType;
 import org.testng.annotations.Test;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
+import static io.trino.parquet.ParquetTestUtils.createParquetWriter;
 import static io.trino.parquet.ParquetTestUtils.generateInputPages;
 import static io.trino.parquet.ParquetTestUtils.writeParquetFile;
 import static io.trino.spi.type.BigintType.BIGINT;
@@ -143,4 +146,33 @@ public void testColumnReordering()
             assertThat(offsets).isSorted();
         }
     }
+
+    @Test
+    public void testWriterMemoryAccounting()
+            throws IOException
+    {
+        List<String> columnNames = ImmutableList.of("columnA", "columnB");
+        List<Type> types = ImmutableList.of(INTEGER, INTEGER);
+
+        ParquetWriter writer = createParquetWriter(
+                new ByteArrayOutputStream(),
+                ParquetWriterOptions.builder()
+                        .setMaxPageSize(DataSize.ofBytes(1024))
+                        .build(),
+                types,
+                columnNames);
+        List<io.trino.spi.Page> inputPages = generateInputPages(types, 1000, 100);
+
+        long previousRetainedBytes = 0;
+        for (io.trino.spi.Page inputPage : inputPages) {
+            checkArgument(types.size() == inputPage.getChannelCount());
+            writer.write(inputPage);
+            long currentRetainedBytes = writer.getRetainedBytes();
+            assertThat(currentRetainedBytes).isGreaterThanOrEqualTo(previousRetainedBytes);
+            previousRetainedBytes = currentRetainedBytes;
+        }
+        assertThat(previousRetainedBytes).isGreaterThanOrEqualTo(2 * Integer.BYTES * 1000 * 100);
+        writer.close();
+        assertThat(previousRetainedBytes - writer.getRetainedBytes()).isGreaterThanOrEqualTo(2 * Integer.BYTES * 1000 * 100);
+    }
 }
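A note on the bound the new test uses twice: with two INTEGER columns and 1000 x 100 positions in total (whether generateInputPages produces 1000 pages of 100 rows or the reverse, the product is the same), the raw value bytes alone come to

    long rawValueBytes = 2L * Integer.BYTES * 1000 * 100; // 800,000 bytes

so retained memory must climb past 800,000 bytes while writing, and close() must release at least that much.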
