Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use and record writer time zone in ORC files #212

Merged
merged 2 commits into from
Feb 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions presto-orc/src/main/java/io/prestosql/orc/OrcReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ public OrcReader(OrcDataSource orcDataSource, OrcEncoding orcEncoding, DataSize

this.writeValidation = requireNonNull(writeValidation, "writeValidation is null");

validateWrite(validation -> validation.getOrcEncoding() == orcEncoding, "Unexpected ORC encoding");

//
// Read the file tail:
//
Expand Down
10 changes: 6 additions & 4 deletions presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -227,6 +228,7 @@ public OrcRecordReader(

stripeReader = new StripeReader(
orcDataSource,
hiveStorageTimeZone.toTimeZone().toZoneId(),
decompressor,
types,
this.presentColumns,
Expand All @@ -236,7 +238,7 @@ public OrcRecordReader(
metadataReader,
writeValidation);

streamReaders = createStreamReaders(orcDataSource, types, hiveStorageTimeZone, presentColumnsAndTypes.build(), streamReadersSystemMemoryContext);
streamReaders = createStreamReaders(orcDataSource, types, presentColumnsAndTypes.build(), streamReadersSystemMemoryContext);
maxBytesPerCell = new long[streamReaders.length];
nextBatchSize = initialBatchSize;
}
Expand Down Expand Up @@ -511,9 +513,10 @@ private void advanceToNextStripe()
// Give readers access to dictionary streams
InputStreamSources dictionaryStreamSources = stripe.getDictionaryStreamSources();
List<ColumnEncoding> columnEncodings = stripe.getColumnEncodings();
ZoneId timeZone = stripe.getTimeZone();
for (StreamReader column : streamReaders) {
if (column != null) {
column.startStripe(dictionaryStreamSources, columnEncodings);
column.startStripe(timeZone, dictionaryStreamSources, columnEncodings);
}
}

Expand Down Expand Up @@ -553,7 +556,6 @@ private void validateWritePageChecksum()
private static StreamReader[] createStreamReaders(
OrcDataSource orcDataSource,
List<OrcType> types,
DateTimeZone hiveStorageTimeZone,
Map<Integer, Type> includedColumns,
AggregatedMemoryContext systemMemoryContext)
{
Expand All @@ -564,7 +566,7 @@ private static StreamReader[] createStreamReaders(
for (int columnId = 0; columnId < rowType.getFieldCount(); columnId++) {
if (includedColumns.containsKey(columnId)) {
StreamDescriptor streamDescriptor = streamDescriptors.get(columnId);
streamReaders[columnId] = StreamReaders.createStreamReader(streamDescriptor, hiveStorageTimeZone, systemMemoryContext);
streamReaders[columnId] = StreamReaders.createStreamReader(streamDescriptor, systemMemoryContext);
}
}
return streamReaders;
Expand Down
59 changes: 51 additions & 8 deletions presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import io.prestosql.spi.type.VarcharType;
import org.openjdk.jol.info.ClassLayout;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -99,8 +100,10 @@ public enum OrcWriteValidationMode
HASHED, DETAILED, BOTH
}

private final OrcEncoding orcEncoding;
private final List<Integer> version;
private final CompressionKind compression;
private final ZoneId timeZone;
private final int rowGroupMaxRowCount;
private final List<String> columnNames;
private final Map<String, Slice> metadata;
Expand All @@ -111,8 +114,10 @@ public enum OrcWriteValidationMode
private final int stringStatisticsLimitInBytes;

private OrcWriteValidation(
OrcEncoding orcEncoding,
List<Integer> version,
CompressionKind compression,
ZoneId timeZone,
int rowGroupMaxRowCount,
List<String> columnNames,
Map<String, Slice> metadata,
Expand All @@ -122,8 +127,10 @@ private OrcWriteValidation(
List<ColumnStatistics> fileStatistics,
int stringStatisticsLimitInBytes)
{
this.orcEncoding = orcEncoding;
this.version = version;
this.compression = compression;
this.timeZone = timeZone;
this.rowGroupMaxRowCount = rowGroupMaxRowCount;
this.columnNames = columnNames;
this.metadata = metadata;
Expand All @@ -134,6 +141,16 @@ private OrcWriteValidation(
this.stringStatisticsLimitInBytes = stringStatisticsLimitInBytes;
}

/**
 * Returns the {@link OrcEncoding} this file was written with, as recorded by the writer.
 */
public OrcEncoding getOrcEncoding()
{
return orcEncoding;
}

/**
 * Returns {@code true} when the file uses the DWRF encoding. DWRF files omit
 * several pieces of metadata that plain ORC records (writer time zone, file
 * statistics, stripe statistics), so validation of those is skipped for DWRF.
 */
public boolean isDwrf()
{
return orcEncoding == OrcEncoding.DWRF;
}

public List<Integer> getVersion()
{
return version;
Expand All @@ -144,6 +161,20 @@ public CompressionKind getCompression()
return compression;
}

/**
 * Returns the time zone captured at write time; stripe footers read back from
 * the file are validated against this value.
 */
public ZoneId getTimeZone()
{
return timeZone;
}

/**
 * Verifies that the time zone found in a stripe footer matches the time zone
 * recorded when the file was written.
 *
 * @param orcDataSourceId identifies the data source for the error message
 * @param actualTimeZone time zone read from the stripe footer; may be {@code null}
 *         when the footer does not carry one, which fails validation for non-DWRF files
 * @throws OrcCorruptionException if the zones do not match for a non-DWRF file
 */
public void validateTimeZone(OrcDataSourceId orcDataSourceId, ZoneId actualTimeZone)
        throws OrcCorruptionException
{
    // DWRF stripe footers do not record the writer time zone, so there is nothing to compare
    if (isDwrf()) {
        return;
    }
    if (!timeZone.equals(actualTimeZone)) {
        throw new OrcCorruptionException(orcDataSourceId, "Unexpected time zone");
    }
}

public int getRowGroupMaxRowCount()
{
return rowGroupMaxRowCount;
Expand All @@ -162,12 +193,14 @@ public Map<String, Slice> getMetadata()
public void validateMetadata(OrcDataSourceId orcDataSourceId, Map<String, Slice> actualMetadata)
throws OrcCorruptionException
{
// Filter out metadata value statically added by the DWRF writer
Map<String, Slice> filteredMetadata = actualMetadata.entrySet().stream()
.filter(entry -> !STATIC_METADATA.containsKey(entry.getKey()))
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
if (isDwrf()) {
// Filter out metadata value statically added by the DWRF writer
actualMetadata = actualMetadata.entrySet().stream()
.filter(entry -> !STATIC_METADATA.containsKey(entry.getKey()))
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
}

if (!metadata.equals(filteredMetadata)) {
if (!metadata.equals(actualMetadata)) {
throw new OrcCorruptionException(orcDataSourceId, "Unexpected metadata");
}
}
Expand All @@ -180,7 +213,7 @@ public WriteChecksum getChecksum()
public void validateFileStatistics(OrcDataSourceId orcDataSourceId, List<ColumnStatistics> actualFileStatistics)
throws OrcCorruptionException
{
if (actualFileStatistics.isEmpty()) {
if (isDwrf()) {
// DWRF file statistics are disabled
return;
}
Expand All @@ -193,7 +226,7 @@ public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, List<Strip
requireNonNull(actualStripes, "actualStripes is null");
requireNonNull(actualStripeStatistics, "actualStripeStatistics is null");

if (actualStripeStatistics.isEmpty()) {
if (isDwrf()) {
// DWRF does not have stripe statistics
return;
}
Expand Down Expand Up @@ -833,9 +866,11 @@ public static class OrcWriteValidationBuilder
private static final int INSTANCE_SIZE = ClassLayout.parseClass(OrcWriteValidationBuilder.class).instanceSize();

private final OrcWriteValidationMode validationMode;
private final OrcEncoding orcEncoding;

private List<Integer> version;
private CompressionKind compression;
private ZoneId timeZone;
private int rowGroupMaxRowCount;
private int stringStatisticsLimitInBytes;
private List<String> columnNames;
Expand All @@ -847,9 +882,10 @@ public static class OrcWriteValidationBuilder
private List<ColumnStatistics> fileStatistics;
private long retainedSize = INSTANCE_SIZE;

public OrcWriteValidationBuilder(OrcWriteValidationMode validationMode, List<Type> types)
public OrcWriteValidationBuilder(OrcWriteValidationMode validationMode, OrcEncoding orcEncoding, List<Type> types)
{
this.validationMode = validationMode;
this.orcEncoding = orcEncoding;
this.checksum = new WriteChecksumBuilder(types);
}

Expand All @@ -869,6 +905,11 @@ public void setCompression(CompressionKind compression)
this.compression = compression;
}

/**
 * Records the writer's time zone so reads of the finished file can be
 * validated against it.
 */
public void setTimeZone(ZoneId timeZone)
{
this.timeZone = timeZone;
}

public void setRowGroupMaxRowCount(int rowGroupMaxRowCount)
{
this.rowGroupMaxRowCount = rowGroupMaxRowCount;
Expand Down Expand Up @@ -932,8 +973,10 @@ public void setFileStatistics(List<ColumnStatistics> fileStatistics)
public OrcWriteValidation build()
{
return new OrcWriteValidation(
orcEncoding,
version,
compression,
timeZone,
rowGroupMaxRowCount,
columnNames,
metadata,
Expand Down
8 changes: 6 additions & 2 deletions presto-orc/src/main/java/io/prestosql/orc/OrcWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@

import java.io.Closeable;
import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -131,14 +133,15 @@ public OrcWriter(
OrcWriteValidationMode validationMode,
OrcWriterStats stats)
{
this.validationBuilder = validate ? new OrcWriteValidationBuilder(validationMode, types)
this.validationBuilder = validate ? new OrcWriteValidationBuilder(validationMode, orcEncoding, types)
.setStringStatisticsLimitInBytes(toIntExact(options.getMaxStringStatisticsLimit().toBytes())) : null;

this.orcDataSink = requireNonNull(orcDataSink, "orcDataSink is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.orcEncoding = requireNonNull(orcEncoding, "orcEncoding is null");
this.compression = requireNonNull(compression, "compression is null");
recordValidation(validation -> validation.setCompression(compression));
recordValidation(validation -> validation.setTimeZone(hiveStorageTimeZone.toTimeZone().toZoneId()));

requireNonNull(options, "options is null");
checkArgument(options.getStripeMaxSize().compareTo(options.getStripeMinSize()) >= 0, "stripeMaxSize must be greater than stripeMinSize");
Expand Down Expand Up @@ -411,7 +414,8 @@ private List<OrcDataOutput> bufferStripeData(long stripeStartOffset, FlushReason
columnStatistics.put(0, new ColumnStatistics((long) stripeRowCount, 0, null, null, null, null, null, null, null, null));

// add footer
StripeFooter stripeFooter = new StripeFooter(allStreams, toDenseList(columnEncodings, orcTypes.size()));
Optional<ZoneId> timeZone = Optional.of(hiveStorageTimeZone.toTimeZone().toZoneId());
StripeFooter stripeFooter = new StripeFooter(allStreams, toDenseList(columnEncodings, orcTypes.size()), timeZone);
Slice footer = metadataWriter.writeStripeFooter(stripeFooter);
outputData.add(createDataOutput(footer));

Expand Down
11 changes: 10 additions & 1 deletion presto-orc/src/main/java/io/prestosql/orc/Stripe.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.prestosql.orc.metadata.ColumnEncoding;
import io.prestosql.orc.stream.InputStreamSources;

import java.time.ZoneId;
import java.util.List;

import static com.google.common.base.MoreObjects.toStringHelper;
Expand All @@ -25,13 +26,15 @@
public class Stripe
{
private final long rowCount;
private final ZoneId timeZone;
private final List<ColumnEncoding> columnEncodings;
private final List<RowGroup> rowGroups;
private final InputStreamSources dictionaryStreamSources;

public Stripe(long rowCount, List<ColumnEncoding> columnEncodings, List<RowGroup> rowGroups, InputStreamSources dictionaryStreamSources)
public Stripe(long rowCount, ZoneId timeZone, List<ColumnEncoding> columnEncodings, List<RowGroup> rowGroups, InputStreamSources dictionaryStreamSources)
{
this.rowCount = rowCount;
this.timeZone = requireNonNull(timeZone, "timeZone is null");
this.columnEncodings = requireNonNull(columnEncodings, "columnEncodings is null");
this.rowGroups = ImmutableList.copyOf(requireNonNull(rowGroups, "rowGroups is null"));
this.dictionaryStreamSources = requireNonNull(dictionaryStreamSources, "dictionaryStreamSources is null");
Expand All @@ -42,6 +45,11 @@ public long getRowCount()
return rowCount;
}

/**
 * Returns the time zone to use when interpreting this stripe's timestamp data:
 * the zone from the stripe footer, or the reader's default when the footer has none.
 */
public ZoneId getTimeZone()
{
return timeZone;
}

public List<ColumnEncoding> getColumnEncodings()
{
return columnEncodings;
Expand All @@ -62,6 +70,7 @@ public String toString()
{
return toStringHelper(this)
.add("rowCount", rowCount)
.add("timeZone", timeZone)
.add("columnEncodings", columnEncodings)
.add("rowGroups", rowGroups)
.add("dictionaryStreams", dictionaryStreamSources)
Expand Down
12 changes: 10 additions & 2 deletions presto-orc/src/main/java/io/prestosql/orc/StripeReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -75,6 +76,7 @@
public class StripeReader
{
private final OrcDataSource orcDataSource;
private final ZoneId defaultTimeZone;
private final Optional<OrcDecompressor> decompressor;
private final List<OrcType> types;
private final HiveWriterVersion hiveWriterVersion;
Expand All @@ -85,6 +87,7 @@ public class StripeReader
private final Optional<OrcWriteValidation> writeValidation;

public StripeReader(OrcDataSource orcDataSource,
ZoneId defaultTimeZone,
Optional<OrcDecompressor> decompressor,
List<OrcType> types,
Set<Integer> includedColumns,
Expand All @@ -95,6 +98,7 @@ public StripeReader(OrcDataSource orcDataSource,
Optional<OrcWriteValidation> writeValidation)
{
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
this.defaultTimeZone = requireNonNull(defaultTimeZone, "defaultTimeZone is null");
this.decompressor = requireNonNull(decompressor, "decompressor is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.includedOrcColumns = getIncludedOrcColumns(types, requireNonNull(includedColumns, "includedColumns is null"));
Expand All @@ -111,6 +115,10 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
// read the stripe footer
StripeFooter stripeFooter = readStripeFooter(stripe, systemMemoryUsage);
List<ColumnEncoding> columnEncodings = stripeFooter.getColumnEncodings();
if (writeValidation.isPresent()) {
writeValidation.get().validateTimeZone(orcDataSource.getId(), stripeFooter.getTimeZone().orElse(null));
}
ZoneId timeZone = stripeFooter.getTimeZone().orElse(defaultTimeZone);

// get streams for selected columns
Map<StreamId, Stream> streams = new HashMap<>();
Expand Down Expand Up @@ -182,7 +190,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
selectedRowGroups,
columnEncodings);

return new Stripe(stripe.getNumberOfRows(), columnEncodings, rowGroups, dictionaryStreamSources);
return new Stripe(stripe.getNumberOfRows(), timeZone, columnEncodings, rowGroups, dictionaryStreamSources);
}
catch (InvalidCheckpointException e) {
// The ORC file contains a corrupt checkpoint stream
Expand Down Expand Up @@ -241,7 +249,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
}
RowGroup rowGroup = new RowGroup(0, 0, stripe.getNumberOfRows(), minAverageRowBytes, new InputStreamSources(builder.build()));

return new Stripe(stripe.getNumberOfRows(), columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
return new Stripe(stripe.getNumberOfRows(), timeZone, columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
}

public Map<StreamId, OrcInputStream> readDiskRanges(long stripeOffset, Map<StreamId, DiskRange> diskRanges, AggregatedMemoryContext systemMemoryUsage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public StripeFooter readStripeFooter(List<OrcType> types, InputStream inputStrea
{
CodedInputStream input = CodedInputStream.newInstance(inputStream);
DwrfProto.StripeFooter stripeFooter = DwrfProto.StripeFooter.parseFrom(input);
return new StripeFooter(toStream(stripeFooter.getStreamsList()), toColumnEncoding(types, stripeFooter.getColumnsList()));
return new StripeFooter(toStream(stripeFooter.getStreamsList()), toColumnEncoding(types, stripeFooter.getColumnsList()), Optional.empty());
}

private static Stream toStream(DwrfProto.Stream stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public StripeFooter readStripeFooter(List<OrcType> types, InputStream inputStrea
try {
return delegate.readStripeFooter(types, inputStream);
}
catch (IOException e) {
catch (IOException | RuntimeException e) {
throw propagate(e, "Invalid stripe footer");
}
}
Expand Down
Loading