Add native parquet writer
qqibrow authored and dain committed Feb 19, 2020
1 parent 5cb9de3 commit 70b1289
Showing 31 changed files with 3,293 additions and 5 deletions.
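
The commit adds a Presto-native ParquetWriter in the presto-parquet module and wires it into the Hive benchmark and the ParquetTester round-trip harness shown below. As orientation, here is a minimal usage sketch assembled only from the constructor and methods visible in this diff; the class and method names, the options-builder defaults, and the CompressionCodecName import path are illustrative assumptions, not part of the commit:

import io.prestosql.parquet.writer.ParquetWriter;
import io.prestosql.parquet.writer.ParquetWriterOptions;
import io.prestosql.spi.Page;
import io.prestosql.spi.type.Type;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;

public class ParquetWriterExample
{
    static void writePages(File target, List<String> columnNames, List<Type> columnTypes, List<Page> pages)
            throws IOException
    {
        // the writer wraps a plain OutputStream; no Hadoop Path or FileSystem is involved
        ParquetWriter writer = new ParquetWriter(
                new FileOutputStream(target),
                columnNames,
                columnTypes,
                ParquetWriterOptions.builder().build(), // builder defaults, as used by the benchmark and tests below
                CompressionCodecName.SNAPPY);
        for (Page page : pages) {
            writer.write(page); // pages may be written repeatedly before closing
        }
        writer.close(); // finalizes the file
    }
}

Writing against a plain OutputStream is what lets both the benchmark writer and the test harness below target a simple FileOutputStream.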
@@ -21,6 +21,8 @@
import io.prestosql.orc.OrcWriterStats;
import io.prestosql.orc.OutputStreamOrcDataSink;
import io.prestosql.orc.metadata.OrcType;
import io.prestosql.parquet.writer.ParquetWriter;
import io.prestosql.parquet.writer.ParquetWriterOptions;
import io.prestosql.plugin.hive.FileFormatDataSourceStats;
import io.prestosql.plugin.hive.HdfsEnvironment;
import io.prestosql.plugin.hive.HiveColumnHandle;
@@ -170,8 +172,9 @@ public FormatWriter createFileFormatWriter(
List<String> columnNames,
List<Type> columnTypes,
HiveCompressionCodec compressionCodec)
throws IOException
{
-return new RecordFormatWriter(targetFile, columnNames, columnTypes, compressionCodec, HiveStorageFormat.PARQUET, session);
+return new PrestoParquetFormatWriter(targetFile, columnNames, columnTypes, compressionCodec);
}
},

@@ -477,4 +480,35 @@ public void close()
writer.close();
}
}

private static class PrestoParquetFormatWriter
implements FormatWriter
{
private final ParquetWriter writer;

public PrestoParquetFormatWriter(File targetFile, List<String> columnNames, List<Type> types, HiveCompressionCodec compressionCodec)
throws IOException
{
writer = new ParquetWriter(
new FileOutputStream(targetFile),
columnNames,
types,
ParquetWriterOptions.builder().build(),
compressionCodec.getParquetCompressionCodec());
}

@Override
public void writePage(Page page)
throws IOException
{
writer.write(page);
}

@Override
public void close()
throws IOException
{
writer.close();
}
}
}
@@ -17,9 +17,13 @@
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.prestosql.parquet.writer.ParquetWriter;
import io.prestosql.parquet.writer.ParquetWriterOptions;
import io.prestosql.plugin.hive.HiveConfig;
import io.prestosql.plugin.hive.HiveSessionProperties;
import io.prestosql.plugin.hive.HiveStorageFormat;
@@ -31,21 +35,25 @@
import io.prestosql.plugin.hive.parquet.write.SingleLevelArraySchemaConverter;
import io.prestosql.plugin.hive.parquet.write.TestMapredParquetOutputFormat;
import io.prestosql.spi.Page;
import io.prestosql.spi.PageBuilder;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.connector.ConnectorPageSource;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.RecordCursor;
import io.prestosql.spi.connector.RecordPageSource;
import io.prestosql.spi.type.ArrayType;
-import io.prestosql.spi.type.DateType;
+import io.prestosql.spi.type.CharType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Decimals;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.SqlDate;
import io.prestosql.spi.type.SqlDecimal;
import io.prestosql.spi.type.SqlTimestamp;
import io.prestosql.spi.type.SqlVarbinary;
import io.prestosql.spi.type.TimestampType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.testing.TestingConnectorSession;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter;
@@ -65,6 +73,7 @@

import java.io.Closeable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigInteger;
@@ -75,13 +84,17 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Properties;
import java.util.Set;
import java.util.function.Function;

import static com.google.common.base.Functions.constant;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Iterables.transform;
import static io.airlift.slice.Slices.utf8Slice;
import static io.prestosql.plugin.hive.AbstractTestHiveFileFormats.getFieldFromCursor;
import static io.prestosql.plugin.hive.HiveSessionProperties.getParquetMaxReadBlockSize;
import static io.prestosql.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT;
@@ -90,9 +103,21 @@
import static io.prestosql.plugin.hive.util.HiveUtil.isMapType;
import static io.prestosql.plugin.hive.util.HiveUtil.isRowType;
import static io.prestosql.plugin.hive.util.HiveUtil.isStructuralType;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.Chars.truncateToLengthAndTrimSpaces;
import static io.prestosql.spi.type.DateType.DATE;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.IntegerType.INTEGER;
import static io.prestosql.spi.type.RealType.REAL;
import static io.prestosql.spi.type.SmallintType.SMALLINT;
import static io.prestosql.spi.type.TimeZoneKey.UTC_KEY;
import static io.prestosql.spi.type.TimestampType.TIMESTAMP;
import static io.prestosql.spi.type.TinyintType.TINYINT;
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
import static io.prestosql.spi.type.Varchars.isVarcharType;
import static io.prestosql.spi.type.Varchars.truncateToLength;
import static java.lang.Math.toIntExact;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.stream;
import static java.util.Collections.singletonList;
@@ -115,13 +140,18 @@
public class ParquetTester
{
public static final DateTimeZone HIVE_STORAGE_TIME_ZONE = DateTimeZone.forID("America/Bahia_Banderas");

private static final int MAX_PRECISION_INT64 = toIntExact(maxPrecision(8));

private static final boolean OPTIMIZED = true;
private static final ConnectorSession SESSION = getHiveSession(createHiveConfig(false));
private static final ConnectorSession SESSION_USE_NAME = getHiveSession(createHiveConfig(true));
private static final List<String> TEST_COLUMN = singletonList("test");

private Set<CompressionCodecName> compressions = ImmutableSet.of();

private Set<CompressionCodecName> writerCompressions = ImmutableSet.of();

private Set<WriterVersion> versions = ImmutableSet.of();

private Set<ConnectorSession> sessions = ImmutableSet.of();
@@ -130,6 +160,7 @@ public static ParquetTester quickParquetTester()
{
ParquetTester parquetTester = new ParquetTester();
parquetTester.compressions = ImmutableSet.of(GZIP);
parquetTester.writerCompressions = ImmutableSet.of(GZIP);
parquetTester.versions = ImmutableSet.of(PARQUET_1_0);
parquetTester.sessions = ImmutableSet.of(SESSION);
return parquetTester;
@@ -139,6 +170,7 @@ public static ParquetTester fullParquetTester()
{
ParquetTester parquetTester = new ParquetTester();
parquetTester.compressions = ImmutableSet.of(GZIP, UNCOMPRESSED, SNAPPY, LZO, LZ4, ZSTD);
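// note: the native writer path is exercised with a narrower codec set (no LZO or LZ4)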
parquetTester.writerCompressions = ImmutableSet.of(GZIP, UNCOMPRESSED, SNAPPY, ZSTD);
parquetTester.versions = ImmutableSet.copyOf(WriterVersion.values());
parquetTester.sessions = ImmutableSet.of(SESSION, SESSION_USE_NAME);
return parquetTester;
@@ -311,6 +343,23 @@ void assertRoundTrip(
}
}
}

// write the values with the Presto-native Parquet writer, then assert the file reads back correctly
for (CompressionCodecName compressionCodecName : writerCompressions) {
for (ConnectorSession session : sessions) {
try (TempFile tempFile = new TempFile("test", "parquet")) {
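// each value iterable should yield the same number of rows; take the smallest count as the page size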
OptionalInt min = stream(writeValues).mapToInt(Iterables::size).min();
checkState(min.isPresent());
writeParquetColumnPresto(tempFile.getFile(), columnTypes, columnNames, getIterators(readValues), min.getAsInt(), compressionCodecName);
assertFileContents(
session,
tempFile.getFile(),
getIterators(readValues),
columnNames,
columnTypes);
}
}
}
}

static void testMaxReadBytes(ObjectInspector objectInspector, Iterable<?> writeValues, Iterable<?> readValues, Type type, DataSize maxReadBlockSize)
@@ -476,10 +525,10 @@ private static Object getActualCursorValue(RecordCursor cursor, Type type, int f
if (VARBINARY.equals(type)) {
return new SqlVarbinary(((Slice) fieldFromCursor).getBytes());
}
-if (DateType.DATE.equals(type)) {
+if (DATE.equals(type)) {
return new SqlDate(((Long) fieldFromCursor).intValue());
}
-if (TimestampType.TIMESTAMP.equals(type)) {
+if (TIMESTAMP.equals(type)) {
return new SqlTimestamp((long) fieldFromCursor, UTC_KEY);
}
return fieldFromCursor;
@@ -658,4 +707,125 @@ private static Object decodeObject(Type type, Block block, int position)

return type.getObjectValue(SESSION, block, position);
}

private static void writeParquetColumnPresto(File outputFile, List<Type> types, List<String> columnNames, Iterator<?>[] values, int size, CompressionCodecName compressionCodecName)
throws Exception
{
checkArgument(types.size() == columnNames.size() && types.size() == values.length);
ParquetWriter writer = new ParquetWriter(
new FileOutputStream(outputFile),
columnNames,
types,
ParquetWriterOptions.builder().build(),
compressionCodecName);

PageBuilder pageBuilder = new PageBuilder(types);
for (int i = 0; i < types.size(); ++i) {
Type type = types.get(i);
Iterator<?> iterator = values[i];
BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(i);

for (int j = 0; j < size; ++j) {
checkState(iterator.hasNext());
Object value = iterator.next();
writeValue(type, blockBuilder, value);
}
}
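// rows were appended directly to each channel's BlockBuilder above, so declare the page's row count explicitly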
pageBuilder.declarePositions(size);
writer.write(pageBuilder.build());
writer.close();
}

private static void writeValue(Type type, BlockBuilder blockBuilder, Object value)
{
if (value == null) {
blockBuilder.appendNull();
}
else {
if (BOOLEAN.equals(type)) {
type.writeBoolean(blockBuilder, (Boolean) value);
}
else if (TINYINT.equals(type) || SMALLINT.equals(type) || INTEGER.equals(type) || BIGINT.equals(type)) {
type.writeLong(blockBuilder, ((Number) value).longValue());
}
else if (Decimals.isShortDecimal(type)) {
type.writeLong(blockBuilder, ((SqlDecimal) value).getUnscaledValue().longValue());
}
else if (Decimals.isLongDecimal(type)) {
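// an unscaled value wider than MAX_PRECISION_INT64 digits cannot round-trip through a long, so encode it from the full BigInteger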
if (Decimals.overflows(((SqlDecimal) value).getUnscaledValue(), MAX_PRECISION_INT64)) {
type.writeSlice(blockBuilder, Decimals.encodeUnscaledValue(((SqlDecimal) value).toBigDecimal().unscaledValue()));
}
else {
type.writeSlice(blockBuilder, Decimals.encodeUnscaledValue(((SqlDecimal) value).getUnscaledValue().longValue()));
}
}
else if (DOUBLE.equals(type)) {
type.writeDouble(blockBuilder, ((Number) value).doubleValue());
}
else if (REAL.equals(type)) {
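// REAL is stored in a long-backed block as the float's raw IEEE-754 bits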
float floatValue = ((Number) value).floatValue();
type.writeLong(blockBuilder, Float.floatToIntBits(floatValue));
}
else if (type instanceof VarcharType) {
Slice slice = truncateToLength(utf8Slice((String) value), type);
type.writeSlice(blockBuilder, slice);
}
else if (type instanceof CharType) {
Slice slice = truncateToLengthAndTrimSpaces(utf8Slice((String) value), type);
type.writeSlice(blockBuilder, slice);
}
else if (VARBINARY.equals(type)) {
type.writeSlice(blockBuilder, Slices.wrappedBuffer(((SqlVarbinary) value).getBytes()));
}
else if (DATE.equals(type)) {
long days = ((SqlDate) value).getDays();
type.writeLong(blockBuilder, days);
}
else if (TIMESTAMP.equals(type)) {
long millis = ((SqlTimestamp) value).getMillisUtc();
type.writeLong(blockBuilder, millis);
}
else {
if (type instanceof ArrayType) {
List<?> array = (List<?>) value;
Type elementType = type.getTypeParameters().get(0);
BlockBuilder arrayBlockBuilder = blockBuilder.beginBlockEntry();
for (Object elementValue : array) {
writeValue(elementType, arrayBlockBuilder, elementValue);
}
blockBuilder.closeEntry();
}
else if (type instanceof MapType) {
Map<?, ?> map = (Map<?, ?>) value;
Type keyType = type.getTypeParameters().get(0);
Type valueType = type.getTypeParameters().get(1);
BlockBuilder mapBlockBuilder = blockBuilder.beginBlockEntry();
for (Map.Entry<?, ?> entry : map.entrySet()) {
writeValue(keyType, mapBlockBuilder, entry.getKey());
writeValue(valueType, mapBlockBuilder, entry.getValue());
}
blockBuilder.closeEntry();
}
else if (type instanceof RowType) {
List<?> array = (List<?>) value;
List<Type> fieldTypes = type.getTypeParameters();
BlockBuilder rowBlockBuilder = blockBuilder.beginBlockEntry();
for (int fieldId = 0; fieldId < fieldTypes.size(); fieldId++) {
Type fieldType = fieldTypes.get(fieldId);
writeValue(fieldType, rowBlockBuilder, array.get(fieldId));
}
blockBuilder.closeEntry();
}
else {
throw new IllegalArgumentException("Unsupported type " + type);
}
}
}
}

// copied from Parquet code to determine the max decimal precision supported by INT32/INT64
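// e.g. maxPrecision(4) = floor(log10(2^31 - 1)) = 9, and maxPrecision(8) = floor(log10(2^63 - 1)) = 18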
private static long maxPrecision(int numBytes)
{
return Math.round(Math.floor(Math.log10(Math.pow(2, 8 * numBytes - 1) - 1)));
}
}
presto-parquet/pom.xml: 6 additions & 0 deletions
@@ -58,6 +58,12 @@
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<optional>true</optional>
</dependency>

<!-- used by tests but also needed transitively -->
<dependency>
<groupId>io.airlift</groupId>