Skip to content

Commit

Permalink
ORC-991: Fix the bug of encrypted column read crash (#905)
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

1. RowIndex is never marked as encrypted.  The StreamName constructor adds encryption to make rowIndex encrypted when needed.

TreeWriterBase.java
```java
  public void writeStripe(int requiredIndexEntries) throws IOException {
      .....
-    context.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX), rowIndex);
+    context.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX, encryption), rowIndex);
      .....
  }
```

2.  findStreams in StripePlanner.java total offset order and write inconsistency.
finalizeStripe In PhysicalFsWriter.java
> 1. write the unencrypted index streams 
> 2. write the encrypted index streams 
> 3. write the unencrypted data streams
> 4. write the encrypted data streams
```java
  @OverRide
  public void finalizeStripe(OrcProto.StripeFooter.Builder footerBuilder,
                             OrcProto.StripeInformation.Builder dirEntry
                             ) throws IOException {
    .....
    // write the unencrypted index streams
    unencrypted.writeStreams(StreamName.Area.INDEX, rawWriter);
    // write the encrypted index streams
    for (VariantTracker variant: variants.values()) {
      variant.writeStreams(StreamName.Area.INDEX, rawWriter);
    }

    // write the unencrypted data streams
    unencrypted.writeStreams(StreamName.Area.DATA, rawWriter);
    // write out the encrypted data streams
    for (VariantTracker variant: variants.values()) {
      variant.writeStreams(StreamName.Area.DATA, rawWriter);
    }
    .....
  }
```

findStreams in StripePlanner.java  
> 1. total offset the unencrypted index/data streams 
> 2. total offset encrypted index/data streams 

```java
  private void findStreams(long streamStart,
                           OrcProto.StripeFooter footer,
                           boolean[] columnInclude) throws IOException {
    long currentOffset = streamStart;
    Arrays.fill(bloomFilterKinds, null);
    for(OrcProto.Stream stream: footer.getStreamsList()) {
      currentOffset += handleStream(currentOffset, columnInclude, stream, null);
    }

    // Add the encrypted streams that we are using
    for(ReaderEncryptionVariant variant: encryption.getVariants()) {
      int variantId = variant.getVariantId();
      OrcProto.StripeEncryptionVariant stripeVariant =
          footer.getEncryption(variantId);
      for(OrcProto.Stream stream: stripeVariant.getStreamsList()) {
        currentOffset += handleStream(currentOffset, columnInclude, stream, variant);
      }
    }
  }
```


Causes misalignment of data reading.
This pr ensures that the read offset is consistent with the write.


3. Fix Decimal64TreeWriter stream not binding encryption.

### Why are the changes needed?

Fix the bug of encrypted column read crash.

### How was this patch tested?

Added unit test for reading encrypted columns.
  • Loading branch information
guiyanakuang authored Sep 12, 2021
1 parent ebf33dc commit 792c3f8
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 6 deletions.
34 changes: 30 additions & 4 deletions java/core/src/java/org/apache/orc/impl/reader/StripePlanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.orc.impl.CryptoUtils;
import org.apache.orc.impl.InStream;
import org.apache.orc.impl.OrcIndex;
import org.apache.orc.impl.PhysicalFsWriter;
import org.apache.orc.impl.RecordReaderUtils;
import org.apache.orc.impl.StreamName;
import org.apache.orc.impl.reader.tree.TypeReader;
Expand Down Expand Up @@ -273,18 +274,20 @@ private void buildEncodings(OrcProto.StripeFooter footer,
* @param offset the position in the file for this stream
* @param columnInclude which columns are being read
* @param stream the stream to consider
* @param area only the area will be included
* @param variant the variant being read
* @return the offset for the next stream
*/
private long handleStream(long offset,
boolean[] columnInclude,
OrcProto.Stream stream,
StreamName.Area area,
ReaderEncryptionVariant variant) {
int column = stream.getColumn();
if (stream.hasKind()) {
OrcProto.Stream.Kind kind = stream.getKind();

if (kind == OrcProto.Stream.Kind.ENCRYPTED_INDEX ||
if (StreamName.getArea(kind) != area || kind == OrcProto.Stream.Kind.ENCRYPTED_INDEX ||
kind == OrcProto.Stream.Kind.ENCRYPTED_DATA) {
// Ignore the placeholders that shouldn't count toward moving the
// offsets.
Expand Down Expand Up @@ -323,6 +326,8 @@ private long handleStream(long offset,

/**
* Find the complete list of streams.
* CurrentOffset total order must be consistent with write
* {@link PhysicalFsWriter#finalizeStripe}
* @param streamStart the starting offset of streams in the file
* @param footer the footer for the stripe
* @param columnInclude which columns are being read
Expand All @@ -332,19 +337,40 @@ private void findStreams(long streamStart,
boolean[] columnInclude) throws IOException {
long currentOffset = streamStart;
Arrays.fill(bloomFilterKinds, null);
// +-----------------+---------------+-----------------+---------------+
// | | | | |
// | unencrypted | encrypted | unencrypted | encrypted |
// | index | index | data | data |
// | | | | |
// +-----------------+---------------+-----------------+---------------+
// Storage layout of index and data, So we need to find the streams in this order
//
// find index streams, encrypted first and then unencrypted
currentOffset = findStreamsByArea(currentOffset, footer, StreamName.Area.INDEX, columnInclude);

// find data streams, encrypted first and then unencrypted
findStreamsByArea(currentOffset, footer, StreamName.Area.DATA, columnInclude);
}

private long findStreamsByArea(long currentOffset,
OrcProto.StripeFooter footer,
StreamName.Area area,
boolean[] columnInclude) {
// find unencrypted streams
for(OrcProto.Stream stream: footer.getStreamsList()) {
currentOffset += handleStream(currentOffset, columnInclude, stream, null);
currentOffset += handleStream(currentOffset, columnInclude, stream, area, null);
}

// Add the encrypted streams that we are using
// find encrypted streams
for(ReaderEncryptionVariant variant: encryption.getVariants()) {
int variantId = variant.getVariantId();
OrcProto.StripeEncryptionVariant stripeVariant =
footer.getEncryption(variantId);
for(OrcProto.Stream stream: stripeVariant.getStreamsList()) {
currentOffset += handleStream(currentOffset, columnInclude, stream, variant);
currentOffset += handleStream(currentOffset, columnInclude, stream, area, variant);
}
}
return currentOffset;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public Decimal64TreeWriter(TypeDescription schema,
WriterContext writer) throws IOException {
super(schema, encryption, writer);
OutStream stream = writer.createStream(
new StreamName(id, OrcProto.Stream.Kind.DATA));
new StreamName(id, OrcProto.Stream.Kind.DATA, encryption));
// Use RLEv2 until we have the new RLEv3.
valueWriter = new RunLengthIntegerWriterV2(stream, true, true);
scale = schema.getScale();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ public void writeStripe(int requiredIndexEntries) throws IOException {
"index entries found: " + rowIndex.getEntryCount() + " expected: " +
requiredIndexEntries);
}
context.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX), rowIndex);
context.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX, encryption), rowIndex);
rowIndex.clear();
rowIndexEntry.clear();
}
Expand Down
143 changes: 143 additions & 0 deletions java/core/src/test/org/apache/orc/impl/TestEncryption.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.orc.impl;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;
import org.apache.hadoop.hive.ql.io.sarg.SearchArgument;
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory;
import org.apache.orc.EncryptionAlgorithm;
import org.apache.orc.InMemoryKeystore;
import org.apache.orc.OrcConf;
import org.apache.orc.OrcFile;
import org.apache.orc.Reader;
import org.apache.orc.RecordReader;
import org.apache.orc.TypeDescription;
import org.apache.orc.Writer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class TestEncryption {

Path workDir = new Path(System.getProperty("test.tmp.dir"));
Configuration conf;
FileSystem fs;
Path testFilePath;
TypeDescription schema;
KeyProvider keyProvider;
String encryption;
String mask;

@BeforeEach
public void openFileSystem() throws Exception {
conf = new Configuration();
conf.setInt(OrcConf.ROW_INDEX_STRIDE.getAttribute(), VectorizedRowBatch.DEFAULT_SIZE);
fs = FileSystem.getLocal(conf);
fs.setWorkingDirectory(workDir);
testFilePath = new Path("testWriterImpl.orc");
fs.create(testFilePath, true);
schema = TypeDescription.fromString("struct<id:int,name:string>");
byte[] kmsKey = "secret123".getBytes(StandardCharsets.UTF_8);
keyProvider = new InMemoryKeystore()
.addKey("pii", EncryptionAlgorithm.AES_CTR_128, kmsKey);
encryption = "pii:id,name";
mask = "sha256:id,name";
}

@AfterEach
public void deleteTestFile() throws Exception {
fs.delete(testFilePath, false);
}

private void write() throws IOException {
Writer writer = OrcFile.createWriter(testFilePath,
OrcFile.writerOptions(conf)
.setSchema(schema)
.overwrite(true)
.setKeyProvider(keyProvider)
.encrypt(encryption)
.masks(mask));
VectorizedRowBatch batch = schema.createRowBatch();
LongColumnVector id = (LongColumnVector) batch.cols[0];
BytesColumnVector name = (BytesColumnVector) batch.cols[1];
for (int r = 0; r < VectorizedRowBatch.DEFAULT_SIZE * 2; ++r) {
int row = batch.size++;
id.vector[row] = r;
byte[] buffer = ("name-" + (r * 3)).getBytes(StandardCharsets.UTF_8);
name.setRef(row, buffer, 0, buffer.length);
if (batch.size == batch.getMaxSize()) {
writer.addRowBatch(batch);
batch.reset();
}
}
if (batch.size != 0) {
writer.addRowBatch(batch);
}
writer.close();
}

private void read(boolean pushDown) throws IOException {
Reader reader = OrcFile.createReader(testFilePath,
OrcFile.readerOptions(conf).setKeyProvider(keyProvider));
SearchArgument searchArgument = pushDown ? SearchArgumentFactory.newBuilder()
.equals("id", PredicateLeaf.Type.LONG, (long) VectorizedRowBatch.DEFAULT_SIZE)
.build() : null;
VectorizedRowBatch batch = schema.createRowBatch();
Reader.Options options = reader.options().schema(this.schema);
if (pushDown) {
options = options.searchArgument(searchArgument, new String[]{"id"});
}
RecordReader rowIterator = reader.rows(options);
LongColumnVector idColumn = (LongColumnVector) batch.cols[0];
BytesColumnVector nameColumn = (BytesColumnVector) batch.cols[1];
int batchNum = pushDown ? 1 : 0;
while (rowIterator.nextBatch(batch)) {
for (int row = 0; row < batch.size; ++row) {
long value = row + ((long) batchNum * VectorizedRowBatch.DEFAULT_SIZE);
assertEquals(value, idColumn.vector[row]);
assertEquals("name-" + (value * 3), nameColumn.toString(row));
}
batchNum ++;
}
rowIterator.close();
}

@Test
public void testReadEncryption() throws IOException {
write();
read(false);
}

@Test
public void testPushDownReadEncryption() throws IOException {
write();
read(true);
}

}

0 comments on commit 792c3f8

Please sign in to comment.