Skip to content

Commit

Permalink
Use and record writer time zone in ORC files
Browse files Browse the repository at this point in the history
Early versions of the Apache ORC writer made the mistake of recording
timestamps from an epoch that was relative to the time zone of the
writer. This was fixed in later versions by recording the writer time
zone in the stripe footer. Hive 3.1 always writes using UTC.

Presto used a global configuration for the writer time zone, which
was needed to handle old files, but was never updated to use the time
zone from the stripe footer.

On read, Presto now uses the stripe value if present, otherwise it
uses the configured value. On write, Presto continues to write
timestamps using the configured time zone, but now records this value
when writing files.
  • Loading branch information
electrum committed Feb 15, 2019
1 parent 82ff60d commit f086b52
Show file tree
Hide file tree
Showing 29 changed files with 147 additions and 71 deletions.
10 changes: 6 additions & 4 deletions presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -227,6 +228,7 @@ public OrcRecordReader(

stripeReader = new StripeReader(
orcDataSource,
hiveStorageTimeZone.toTimeZone().toZoneId(),
decompressor,
types,
this.presentColumns,
Expand All @@ -236,7 +238,7 @@ public OrcRecordReader(
metadataReader,
writeValidation);

streamReaders = createStreamReaders(orcDataSource, types, hiveStorageTimeZone, presentColumnsAndTypes.build(), streamReadersSystemMemoryContext);
streamReaders = createStreamReaders(orcDataSource, types, presentColumnsAndTypes.build(), streamReadersSystemMemoryContext);
maxBytesPerCell = new long[streamReaders.length];
nextBatchSize = initialBatchSize;
}
Expand Down Expand Up @@ -511,9 +513,10 @@ private void advanceToNextStripe()
// Give readers access to dictionary streams
InputStreamSources dictionaryStreamSources = stripe.getDictionaryStreamSources();
List<ColumnEncoding> columnEncodings = stripe.getColumnEncodings();
ZoneId timeZone = stripe.getTimeZone();
for (StreamReader column : streamReaders) {
if (column != null) {
column.startStripe(dictionaryStreamSources, columnEncodings);
column.startStripe(timeZone, dictionaryStreamSources, columnEncodings);
}
}

Expand Down Expand Up @@ -553,7 +556,6 @@ private void validateWritePageChecksum()
private static StreamReader[] createStreamReaders(
OrcDataSource orcDataSource,
List<OrcType> types,
DateTimeZone hiveStorageTimeZone,
Map<Integer, Type> includedColumns,
AggregatedMemoryContext systemMemoryContext)
{
Expand All @@ -564,7 +566,7 @@ private static StreamReader[] createStreamReaders(
for (int columnId = 0; columnId < rowType.getFieldCount(); columnId++) {
if (includedColumns.containsKey(columnId)) {
StreamDescriptor streamDescriptor = streamDescriptors.get(columnId);
streamReaders[columnId] = StreamReaders.createStreamReader(streamDescriptor, hiveStorageTimeZone, systemMemoryContext);
streamReaders[columnId] = StreamReaders.createStreamReader(streamDescriptor, systemMemoryContext);
}
}
return streamReaders;
Expand Down
25 changes: 25 additions & 0 deletions presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import io.prestosql.spi.type.VarcharType;
import org.openjdk.jol.info.ClassLayout;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -102,6 +103,7 @@ public enum OrcWriteValidationMode
private final OrcEncoding orcEncoding;
private final List<Integer> version;
private final CompressionKind compression;
private final ZoneId timeZone;
private final int rowGroupMaxRowCount;
private final List<String> columnNames;
private final Map<String, Slice> metadata;
Expand All @@ -115,6 +117,7 @@ private OrcWriteValidation(
OrcEncoding orcEncoding,
List<Integer> version,
CompressionKind compression,
ZoneId timeZone,
int rowGroupMaxRowCount,
List<String> columnNames,
Map<String, Slice> metadata,
Expand All @@ -127,6 +130,7 @@ private OrcWriteValidation(
this.orcEncoding = orcEncoding;
this.version = version;
this.compression = compression;
this.timeZone = timeZone;
this.rowGroupMaxRowCount = rowGroupMaxRowCount;
this.columnNames = columnNames;
this.metadata = metadata;
Expand Down Expand Up @@ -157,6 +161,20 @@ public CompressionKind getCompression()
return compression;
}

/**
 * Returns the time zone captured for this write validation.
 *
 * <p>{@link #validateTimeZone} compares this value against the zone read back from the file.
 * NOTE(review): the visible constructor assignment performs no null check, so this
 * presumably can be null if the builder's time zone was never set — confirm.
 *
 * @return the expected writer time zone
 */
public ZoneId getTimeZone()
{
    return this.timeZone;
}

/**
 * Verifies that the time zone found in the file matches the one recorded for this validation.
 *
 * <p>The check is skipped entirely for DWRF, which does not store the writer time zone.
 *
 * @param orcDataSourceId identifies the data source in the raised exception
 * @param actualTimeZone zone read back from the file; a null value never equals the expected zone
 * @throws OrcCorruptionException if the file is not DWRF and the zones differ
 */
public void validateTimeZone(OrcDataSourceId orcDataSourceId, ZoneId actualTimeZone)
        throws OrcCorruptionException
{
    if (isDwrf()) {
        // DWRF does not store the writer time zone
        return;
    }

    boolean zonesMatch = timeZone.equals(actualTimeZone);
    if (zonesMatch) {
        return;
    }

    throw new OrcCorruptionException(orcDataSourceId, "Unexpected time zone");
}

public int getRowGroupMaxRowCount()
{
return rowGroupMaxRowCount;
Expand Down Expand Up @@ -852,6 +870,7 @@ public static class OrcWriteValidationBuilder

private List<Integer> version;
private CompressionKind compression;
private ZoneId timeZone;
private int rowGroupMaxRowCount;
private int stringStatisticsLimitInBytes;
private List<String> columnNames;
Expand Down Expand Up @@ -886,6 +905,11 @@ public void setCompression(CompressionKind compression)
this.compression = compression;
}

/**
 * Records the writer time zone that the built validation will expect.
 *
 * @param timeZone zone passed to the validation by {@code build()}; stored as-is, without a null check
 */
public void setTimeZone(ZoneId timeZone)
{
this.timeZone = timeZone;
}

public void setRowGroupMaxRowCount(int rowGroupMaxRowCount)
{
this.rowGroupMaxRowCount = rowGroupMaxRowCount;
Expand Down Expand Up @@ -952,6 +976,7 @@ public OrcWriteValidation build()
orcEncoding,
version,
compression,
timeZone,
rowGroupMaxRowCount,
columnNames,
metadata,
Expand Down
6 changes: 5 additions & 1 deletion presto-orc/src/main/java/io/prestosql/orc/OrcWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@

import java.io.Closeable;
import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -139,6 +141,7 @@ public OrcWriter(
this.orcEncoding = requireNonNull(orcEncoding, "orcEncoding is null");
this.compression = requireNonNull(compression, "compression is null");
recordValidation(validation -> validation.setCompression(compression));
recordValidation(validation -> validation.setTimeZone(hiveStorageTimeZone.toTimeZone().toZoneId()));

requireNonNull(options, "options is null");
checkArgument(options.getStripeMaxSize().compareTo(options.getStripeMinSize()) >= 0, "stripeMaxSize must be greater than stripeMinSize");
Expand Down Expand Up @@ -411,7 +414,8 @@ private List<OrcDataOutput> bufferStripeData(long stripeStartOffset, FlushReason
columnStatistics.put(0, new ColumnStatistics((long) stripeRowCount, 0, null, null, null, null, null, null, null, null));

// add footer
StripeFooter stripeFooter = new StripeFooter(allStreams, toDenseList(columnEncodings, orcTypes.size()));
Optional<ZoneId> timeZone = Optional.of(hiveStorageTimeZone.toTimeZone().toZoneId());
StripeFooter stripeFooter = new StripeFooter(allStreams, toDenseList(columnEncodings, orcTypes.size()), timeZone);
Slice footer = metadataWriter.writeStripeFooter(stripeFooter);
outputData.add(createDataOutput(footer));

Expand Down
11 changes: 10 additions & 1 deletion presto-orc/src/main/java/io/prestosql/orc/Stripe.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.prestosql.orc.metadata.ColumnEncoding;
import io.prestosql.orc.stream.InputStreamSources;

import java.time.ZoneId;
import java.util.List;

import static com.google.common.base.MoreObjects.toStringHelper;
Expand All @@ -25,13 +26,15 @@
public class Stripe
{
private final long rowCount;
private final ZoneId timeZone;
private final List<ColumnEncoding> columnEncodings;
private final List<RowGroup> rowGroups;
private final InputStreamSources dictionaryStreamSources;

public Stripe(long rowCount, List<ColumnEncoding> columnEncodings, List<RowGroup> rowGroups, InputStreamSources dictionaryStreamSources)
public Stripe(long rowCount, ZoneId timeZone, List<ColumnEncoding> columnEncodings, List<RowGroup> rowGroups, InputStreamSources dictionaryStreamSources)
{
this.rowCount = rowCount;
this.timeZone = requireNonNull(timeZone, "timeZone is null");
this.columnEncodings = requireNonNull(columnEncodings, "columnEncodings is null");
this.rowGroups = ImmutableList.copyOf(requireNonNull(rowGroups, "rowGroups is null"));
this.dictionaryStreamSources = requireNonNull(dictionaryStreamSources, "dictionaryStreamSources is null");
Expand All @@ -42,6 +45,11 @@ public long getRowCount()
return rowCount;
}

/**
 * Returns the time zone associated with this stripe.
 *
 * @return the zone given at construction; never null, since the constructor rejects null
 */
public ZoneId getTimeZone()
{
    return this.timeZone;
}

public List<ColumnEncoding> getColumnEncodings()
{
return columnEncodings;
Expand All @@ -62,6 +70,7 @@ public String toString()
{
return toStringHelper(this)
.add("rowCount", rowCount)
.add("timeZone", timeZone)
.add("columnEncodings", columnEncodings)
.add("rowGroups", rowGroups)
.add("dictionaryStreams", dictionaryStreamSources)
Expand Down
12 changes: 10 additions & 2 deletions presto-orc/src/main/java/io/prestosql/orc/StripeReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -75,6 +76,7 @@
public class StripeReader
{
private final OrcDataSource orcDataSource;
private final ZoneId defaultTimeZone;
private final Optional<OrcDecompressor> decompressor;
private final List<OrcType> types;
private final HiveWriterVersion hiveWriterVersion;
Expand All @@ -85,6 +87,7 @@ public class StripeReader
private final Optional<OrcWriteValidation> writeValidation;

public StripeReader(OrcDataSource orcDataSource,
ZoneId defaultTimeZone,
Optional<OrcDecompressor> decompressor,
List<OrcType> types,
Set<Integer> includedColumns,
Expand All @@ -95,6 +98,7 @@ public StripeReader(OrcDataSource orcDataSource,
Optional<OrcWriteValidation> writeValidation)
{
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
this.defaultTimeZone = requireNonNull(defaultTimeZone, "defaultTimeZone is null");
this.decompressor = requireNonNull(decompressor, "decompressor is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.includedOrcColumns = getIncludedOrcColumns(types, requireNonNull(includedColumns, "includedColumns is null"));
Expand All @@ -111,6 +115,10 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
// read the stripe footer
StripeFooter stripeFooter = readStripeFooter(stripe, systemMemoryUsage);
List<ColumnEncoding> columnEncodings = stripeFooter.getColumnEncodings();
if (writeValidation.isPresent()) {
writeValidation.get().validateTimeZone(orcDataSource.getId(), stripeFooter.getTimeZone().orElse(null));
}
ZoneId timeZone = stripeFooter.getTimeZone().orElse(defaultTimeZone);

// get streams for selected columns
Map<StreamId, Stream> streams = new HashMap<>();
Expand Down Expand Up @@ -182,7 +190,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
selectedRowGroups,
columnEncodings);

return new Stripe(stripe.getNumberOfRows(), columnEncodings, rowGroups, dictionaryStreamSources);
return new Stripe(stripe.getNumberOfRows(), timeZone, columnEncodings, rowGroups, dictionaryStreamSources);
}
catch (InvalidCheckpointException e) {
// The ORC file contains a corrupt checkpoint stream
Expand Down Expand Up @@ -241,7 +249,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
}
RowGroup rowGroup = new RowGroup(0, 0, stripe.getNumberOfRows(), minAverageRowBytes, new InputStreamSources(builder.build()));

return new Stripe(stripe.getNumberOfRows(), columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
return new Stripe(stripe.getNumberOfRows(), timeZone, columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
}

public Map<StreamId, OrcInputStream> readDiskRanges(long stripeOffset, Map<StreamId, DiskRange> diskRanges, AggregatedMemoryContext systemMemoryUsage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public StripeFooter readStripeFooter(List<OrcType> types, InputStream inputStrea
{
CodedInputStream input = CodedInputStream.newInstance(inputStream);
DwrfProto.StripeFooter stripeFooter = DwrfProto.StripeFooter.parseFrom(input);
return new StripeFooter(toStream(stripeFooter.getStreamsList()), toColumnEncoding(types, stripeFooter.getColumnsList()));
return new StripeFooter(toStream(stripeFooter.getStreamsList()), toColumnEncoding(types, stripeFooter.getColumnsList()), Optional.empty());
}

private static Stream toStream(DwrfProto.Stream stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public StripeFooter readStripeFooter(List<OrcType> types, InputStream inputStrea
try {
return delegate.readStripeFooter(types, inputStream);
}
catch (IOException e) {
catch (IOException | RuntimeException e) {
throw propagate(e, "Invalid stripe footer");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
import static io.airlift.slice.SliceUtf8.tryGetCodePointAt;
Expand Down Expand Up @@ -159,7 +161,11 @@ public StripeFooter readStripeFooter(List<OrcType> types, InputStream inputStrea
{
CodedInputStream input = CodedInputStream.newInstance(inputStream);
OrcProto.StripeFooter stripeFooter = OrcProto.StripeFooter.parseFrom(input);
return new StripeFooter(toStream(stripeFooter.getStreamsList()), toColumnEncoding(stripeFooter.getColumnsList()));
return new StripeFooter(
toStream(stripeFooter.getStreamsList()),
toColumnEncoding(stripeFooter.getColumnsList()),
Optional.ofNullable(emptyToNull(stripeFooter.getWriterTimezone()))
.map(zone -> TimeZone.getTimeZone(zone).toZoneId()));
}

private static Stream toStream(OrcProto.Stream stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@

import java.io.IOException;
import java.io.OutputStream;
import java.time.ZoneId;
import java.util.List;
import java.util.Map.Entry;
import java.util.TimeZone;

import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.toIntExact;
Expand Down Expand Up @@ -268,13 +270,16 @@ private static UserMetadataItem toUserMetadata(Entry<String, Slice> entry)
public int writeStripeFooter(SliceOutput output, StripeFooter footer)
throws IOException
{
ZoneId zone = footer.getTimeZone().orElseThrow(() -> new IllegalArgumentException("Time zone not set"));

OrcProto.StripeFooter footerProtobuf = OrcProto.StripeFooter.newBuilder()
.addAllStreams(footer.getStreams().stream()
.map(OrcMetadataWriter::toStream)
.collect(toList()))
.addAllColumns(footer.getColumnEncodings().stream()
.map(OrcMetadataWriter::toColumnEncoding)
.collect(toList()))
.setWriterTimezone(TimeZone.getTimeZone(zone).getID())
.build();

return writeProtobufObject(output, footerProtobuf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@

import com.google.common.collect.ImmutableList;

import java.time.ZoneId;
import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

public class StripeFooter
{
private final List<Stream> streams;
private final List<ColumnEncoding> columnEncodings;
private final Optional<ZoneId> timeZone;

public StripeFooter(List<Stream> streams, List<ColumnEncoding> columnEncodings)
public StripeFooter(List<Stream> streams, List<ColumnEncoding> columnEncodings, Optional<ZoneId> timeZone)
{
this.streams = ImmutableList.copyOf(requireNonNull(streams, "streams is null"));
this.columnEncodings = ImmutableList.copyOf(requireNonNull(columnEncodings, "columnEncodings is null"));
this.timeZone = requireNonNull(timeZone, "timeZone is null");
}

public List<ColumnEncoding> getColumnEncodings()
Expand All @@ -39,4 +43,9 @@ public List<Stream> getStreams()
{
return streams;
}

/**
 * Returns the writer time zone carried by this stripe footer, if any.
 *
 * @return the optional zone supplied at construction; the {@code Optional} itself is never null
 *         because the constructor rejects a null argument
 */
public Optional<ZoneId> getTimeZone()
{
    return this.timeZone;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import javax.annotation.Nullable;

import java.io.IOException;
import java.time.ZoneId;
import java.util.List;

import static com.google.common.base.MoreObjects.toStringHelper;
Expand Down Expand Up @@ -130,7 +131,7 @@ private void openRowGroup()
}

@Override
public void startStripe(InputStreamSources dictionaryStreamSources, List<ColumnEncoding> encoding)
public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List<ColumnEncoding> encoding)
{
presentStreamSource = missingStreamSource(BooleanInputStream.class);
dataStreamSource = missingStreamSource(BooleanInputStream.class);
Expand Down
Loading

0 comments on commit f086b52

Please sign in to comment.