diff --git a/lib/trino-filesystem/pom.xml b/lib/trino-filesystem/pom.xml index 58d99e036459..c367c96d11f2 100644 --- a/lib/trino-filesystem/pom.xml +++ b/lib/trino-filesystem/pom.xml @@ -188,4 +188,18 @@ test + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + about.html + iceberg-build.properties + + + + + diff --git a/lib/trino-orc/pom.xml b/lib/trino-orc/pom.xml index f496879675dd..b13b977c5b02 100644 --- a/lib/trino-orc/pom.xml +++ b/lib/trino-orc/pom.xml @@ -208,6 +208,16 @@ + + org.basepom.maven + duplicate-finder-maven-plugin + + + about.html + iceberg-build.properties + + + diff --git a/lib/trino-parquet/pom.xml b/lib/trino-parquet/pom.xml index 830d2ec1f5e4..10ce60385bf1 100644 --- a/lib/trino-parquet/pom.xml +++ b/lib/trino-parquet/pom.xml @@ -13,12 +13,25 @@ Trino - Parquet file format support + + com.fasterxml.jackson.core + jackson-core + jar + compile + + + + com.fasterxml.jackson.core + jackson-databind + jar + compile + + com.google.errorprone error_prone_annotations true - com.google.guava guava @@ -29,6 +42,11 @@ aircompressor-v3 + + io.airlift + json + + io.airlift log @@ -95,6 +113,12 @@ + + io.trino + trino-filesystem + provided + + io.trino trino-spi diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java index bbece17c9b7e..8eaebc60c93d 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java @@ -21,12 +21,14 @@ public abstract sealed class DataPage { protected final int valueCount; private final OptionalLong firstRowIndex; + private final int pageIndex; - public DataPage(int uncompressedSize, int valueCount, OptionalLong firstRowIndex) + public DataPage(int uncompressedSize, int valueCount, OptionalLong firstRowIndex, int pageIndex) { super(uncompressedSize); this.valueCount = valueCount; this.firstRowIndex = firstRowIndex; + this.pageIndex = pageIndex; } /** 
@@ -41,4 +43,9 @@ public int getValueCount() { return valueCount; } + + public int getPageIndex() + { + return pageIndex; + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java index b0895445d813..8dbf9809378d 100755 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java @@ -35,15 +35,17 @@ public DataPageV1( OptionalLong firstRowIndex, ParquetEncoding repetitionLevelEncoding, ParquetEncoding definitionLevelEncoding, - ParquetEncoding valuesEncoding) + ParquetEncoding valuesEncoding, + int pageIndex) { - super(uncompressedSize, valueCount, firstRowIndex); + super(uncompressedSize, valueCount, firstRowIndex, pageIndex); this.slice = requireNonNull(slice, "slice is null"); this.repetitionLevelEncoding = repetitionLevelEncoding; this.definitionLevelEncoding = definitionLevelEncoding; this.valuesEncoding = valuesEncoding; } + @Override public Slice getSlice() { return slice; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java index b0cbfd9ed8fc..6544942e74eb 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java @@ -44,9 +44,10 @@ public DataPageV2( int uncompressedSize, OptionalLong firstRowIndex, Statistics> statistics, - boolean isCompressed) + boolean isCompressed, + int pageIndex) { - super(uncompressedSize, valueCount, firstRowIndex); + super(uncompressedSize, valueCount, firstRowIndex, pageIndex); this.rowCount = rowCount; this.nullCount = nullCount; this.repetitionLevels = requireNonNull(repetitionLevels, "repetitionLevels slice is null"); @@ -82,6 +83,7 @@ public ParquetEncoding getDataEncoding() return dataEncoding; } + @Override public Slice getSlice() { return slice; diff --git 
a/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java index 74fdf540199d..bd92d7fc0c8e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java @@ -43,6 +43,7 @@ public DictionaryPage(Slice slice, int uncompressedSize, int dictionarySize, Par encoding); } + @Override public Slice getSlice() { return slice; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/EncryptionUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/EncryptionUtils.java new file mode 100644 index 000000000000..26cac6e1a783 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/EncryptionUtils.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet; + +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.crypto.FileDecryptionProperties; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.TrinoCryptoConfigurationUtil; +import io.trino.parquet.crypto.TrinoDecryptionPropertiesFactory; + +import java.lang.reflect.InvocationTargetException; +import java.util.Optional; + +public class EncryptionUtils +{ + public static final Logger LOG = Logger.get(EncryptionUtils.class); + + private EncryptionUtils() {} + + public static Optional createDecryptor(ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + { + if (parquetReaderOptions == null || filePath == null || trinoFileSystem == null) { + return Optional.empty(); + } + + Optional cryptoFactory = loadDecryptionPropertiesFactory(parquetReaderOptions); + Optional fileDecryptionProperties = cryptoFactory.map(factory -> factory.getFileDecryptionProperties(parquetReaderOptions, filePath, trinoFileSystem)); + return fileDecryptionProperties.map(properties -> new InternalFileDecryptor(properties)); + } + + private static Optional loadDecryptionPropertiesFactory(ParquetReaderOptions trinoParquetCryptoConfig) + { + if (trinoParquetCryptoConfig.getCryptoFactoryClass() == null) { + return Optional.empty(); + } + final Class> foundClass = TrinoCryptoConfigurationUtil.getClassFromConfig( + trinoParquetCryptoConfig.getCryptoFactoryClass(), TrinoDecryptionPropertiesFactory.class); + + if (foundClass == null) { + return Optional.empty(); + } + + try { + return Optional.ofNullable((TrinoDecryptionPropertiesFactory) foundClass.getConstructor().newInstance()); + } + catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + LOG.warn("could not instantiate decryptionPropertiesFactoryClass class: " + foundClass, e); + return Optional.empty(); + 
} + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java b/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java index 69cde62cf435..64b1f861717b 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java @@ -13,6 +13,8 @@ */ package io.trino.parquet; +import io.airlift.slice.Slice; + public abstract class Page { protected final int uncompressedSize; @@ -26,4 +28,6 @@ public int getUncompressedSize() { return uncompressedSize; } + + public abstract Slice getSlice(); } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderEncryptionOptions.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderEncryptionOptions.java new file mode 100644 index 000000000000..854eeba0bace --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderEncryptionOptions.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet; + +import io.trino.parquet.crypto.keytools.TrinoKeyToolkit; +import io.trino.parquet.crypto.keytools.TrinoKmsClient; + +public class ParquetReaderEncryptionOptions +{ + final String cryptoFactoryClass; + final String encryptionKmsClientClass; + final String encryptionKmsInstanceId; + final String encryptionKmsInstanceUrl; + final String encryptionKeyAccessToken; + final long encryptionCacheLifetimeSeconds; + final boolean uniformEncryption; + boolean encryptionParameterChecked; + final String failsafeEncryptionKeyId; + final String columnKeys; + final String footerKeyId; + final String[] versionedKeyList; + final String keyFile; + final String[] keyList; + final boolean isEncryptionEnvironmentKeys; + + public ParquetReaderEncryptionOptions() + { + this.cryptoFactoryClass = null; + this.encryptionKmsClientClass = null; + this.encryptionKmsInstanceId = null; + this.encryptionKmsInstanceUrl = null; + this.encryptionKeyAccessToken = TrinoKmsClient.KEY_ACCESS_TOKEN_DEFAULT; + this.encryptionCacheLifetimeSeconds = TrinoKeyToolkit.CACHE_LIFETIME_DEFAULT_SECONDS; + this.uniformEncryption = false; + this.encryptionParameterChecked = false; + this.failsafeEncryptionKeyId = null; + this.footerKeyId = null; + this.columnKeys = null; + this.versionedKeyList = null; + this.keyFile = null; + this.keyList = null; + this.isEncryptionEnvironmentKeys = false; + } + + public ParquetReaderEncryptionOptions(String cryptoFactoryClass, + String encryptionKmsClientClass, + String encryptionKmsInstanceId, + String encryptionKmsInstanceUrl, + String encryptionKeyAccessToken, + long encryptionCacheLifetimeSeconds, + boolean uniformEncryption, + boolean encryptionParameterChecked, + String failsafeEncryptionKeyId, + String footerKeyId, + String columnKeys, + String[] versionedKeyList, + String keyFile, + String[] keyList, + boolean isEncryptionEnvironmentKeys) + { + this.cryptoFactoryClass = cryptoFactoryClass; + this.encryptionKmsClientClass = 
encryptionKmsClientClass; + this.encryptionKmsInstanceId = encryptionKmsInstanceId; + this.encryptionKmsInstanceUrl = encryptionKmsInstanceUrl; + this.encryptionKeyAccessToken = encryptionKeyAccessToken; + this.encryptionCacheLifetimeSeconds = encryptionCacheLifetimeSeconds; + this.uniformEncryption = uniformEncryption; + this.encryptionParameterChecked = encryptionParameterChecked; + this.failsafeEncryptionKeyId = failsafeEncryptionKeyId; + this.footerKeyId = footerKeyId; + this.columnKeys = columnKeys; + this.versionedKeyList = versionedKeyList; + this.keyFile = keyFile; + this.keyList = keyList; + this.isEncryptionEnvironmentKeys = isEncryptionEnvironmentKeys; + } + + public ParquetReaderEncryptionOptions withEncryptionKmsClientClass(String encryptionKmsClientClass) + { + return new ParquetReaderEncryptionOptions(this.cryptoFactoryClass, + encryptionKmsClientClass, + this.encryptionKmsInstanceId, + this.encryptionKmsInstanceUrl, + this.encryptionKeyAccessToken, + this.encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + this.keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } + + public ParquetReaderEncryptionOptions withCryptoFactoryClass(String cryptoFactoryClass) + { + return new ParquetReaderEncryptionOptions(cryptoFactoryClass, + this.encryptionKmsClientClass, + this.encryptionKmsInstanceId, + this.encryptionKmsInstanceUrl, + this.encryptionKeyAccessToken, + this.encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + this.keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } + + public ParquetReaderEncryptionOptions withEncryptionKmsInstanceId(String encryptionKmsInstanceId) + { + return new ParquetReaderEncryptionOptions(this.cryptoFactoryClass, + 
this.encryptionKmsClientClass, + encryptionKmsInstanceId, + this.encryptionKmsInstanceUrl, + this.encryptionKeyAccessToken, + this.encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + this.keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } + + public ParquetReaderEncryptionOptions withEncryptionKmsInstanceUrl(String encryptionKmsInstanceUrl) + { + return new ParquetReaderEncryptionOptions(this.cryptoFactoryClass, + this.encryptionKmsClientClass, + this.encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + this.encryptionKeyAccessToken, + this.encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + this.keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } + + public ParquetReaderEncryptionOptions withEncryptionKeyAccessToken(String encryptionKeyAccessToken) + { + return new ParquetReaderEncryptionOptions(this.cryptoFactoryClass, + this.encryptionKmsClientClass, + this.encryptionKmsInstanceId, + this.encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + this.encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + this.keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } + + public ParquetReaderEncryptionOptions withEncryptionCacheLifetimeSeconds(Long encryptionCacheLifetimeSeconds) + { + return new ParquetReaderEncryptionOptions(this.cryptoFactoryClass, + this.encryptionKmsClientClass, + this.encryptionKmsInstanceId, + this.encryptionKmsInstanceUrl, + this.encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + 
this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + this.keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } + + public ParquetReaderEncryptionOptions withEncryptionKeyFile(String keyFile) + { + return new ParquetReaderEncryptionOptions(this.cryptoFactoryClass, + this.encryptionKmsClientClass, + this.encryptionKmsInstanceId, + this.encryptionKmsInstanceUrl, + this.encryptionKeyAccessToken, + this.encryptionCacheLifetimeSeconds, + this.uniformEncryption, + this.encryptionParameterChecked, + this.failsafeEncryptionKeyId, + this.footerKeyId, + this.columnKeys, + this.versionedKeyList, + keyFile, + this.keyList, + this.isEncryptionEnvironmentKeys); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java index 364e718b71ad..67e7090e0234 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java @@ -36,6 +36,7 @@ public class ParquetReaderOptions private final boolean useBloomFilter; private final DataSize smallFileThreshold; private final boolean vectorizedDecodingEnabled; + private final ParquetReaderEncryptionOptions parquetReaderEncryptionOptions; public ParquetReaderOptions() { @@ -48,6 +49,7 @@ public ParquetReaderOptions() useBloomFilter = true; smallFileThreshold = DEFAULT_SMALL_FILE_THRESHOLD; vectorizedDecodingEnabled = true; + parquetReaderEncryptionOptions = new ParquetReaderEncryptionOptions(); } private ParquetReaderOptions( @@ -59,7 +61,8 @@ private ParquetReaderOptions( boolean useColumnIndex, boolean useBloomFilter, DataSize smallFileThreshold, - boolean vectorizedDecodingEnabled) + boolean vectorizedDecodingEnabled, + ParquetReaderEncryptionOptions parquetReaderEncryptionOptions) { this.ignoreStatistics = ignoreStatistics; this.maxReadBlockSize = requireNonNull(maxReadBlockSize, "maxReadBlockSize is 
null"); @@ -71,6 +74,7 @@ private ParquetReaderOptions( this.useBloomFilter = useBloomFilter; this.smallFileThreshold = requireNonNull(smallFileThreshold, "smallFileThreshold is null"); this.vectorizedDecodingEnabled = vectorizedDecodingEnabled; + this.parquetReaderEncryptionOptions = parquetReaderEncryptionOptions; } public boolean isIgnoreStatistics() @@ -118,6 +122,91 @@ public DataSize getSmallFileThreshold() return smallFileThreshold; } + public String getCryptoFactoryClass() + { + return parquetReaderEncryptionOptions.cryptoFactoryClass; + } + + public long getEncryptionCacheLifetimeSeconds() + { + return this.parquetReaderEncryptionOptions.encryptionCacheLifetimeSeconds; + } + + public String getEncryptionKeyAccessToken() + { + return this.parquetReaderEncryptionOptions.encryptionKeyAccessToken; + } + + public String getEncryptionKmsInstanceId() + { + return this.parquetReaderEncryptionOptions.encryptionKmsInstanceId; + } + + public String getEncryptionKmsInstanceUrl() + { + return this.parquetReaderEncryptionOptions.encryptionKmsInstanceUrl; + } + + public String getEncryptionKmsClientClass() + { + return this.parquetReaderEncryptionOptions.encryptionKmsClientClass; + } + + public boolean isUniformEncryption() + { + return parquetReaderEncryptionOptions.uniformEncryption; + } + + public boolean isEncryptionParameterChecked() + { + return parquetReaderEncryptionOptions.encryptionParameterChecked; + } + + public String getFailsafeEncryptionKeyId() + { + return parquetReaderEncryptionOptions.failsafeEncryptionKeyId; + } + + public String getEncryptionColumnKeys() + { + return parquetReaderEncryptionOptions.columnKeys; + } + + public String getEncryptionFooterKeyId() + { + return parquetReaderEncryptionOptions.footerKeyId; + } + + public String[] getEncryptionVersionedKeyList() + { + return parquetReaderEncryptionOptions.versionedKeyList; + } + + public String[] getEncryptionKeyList() + { + return parquetReaderEncryptionOptions.keyList; + } + + public String 
getEncryptionKeyFile() + { + return parquetReaderEncryptionOptions.keyFile; + } + + public boolean isEncryptionEnvironmentKeys() + { + return parquetReaderEncryptionOptions.isEncryptionEnvironmentKeys; + } + + public void setEncryptionParameterChecked(boolean encryptionParameterChecked) + { + parquetReaderEncryptionOptions.encryptionParameterChecked = encryptionParameterChecked; + } + + public ParquetReaderEncryptionOptions encryptionOptions() + { + return this.parquetReaderEncryptionOptions; + } + public ParquetReaderOptions withIgnoreStatistics(boolean ignoreStatistics) { return new ParquetReaderOptions( @@ -129,7 +218,8 @@ public ParquetReaderOptions withIgnoreStatistics(boolean ignoreStatistics) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withMaxReadBlockSize(DataSize maxReadBlockSize) @@ -143,7 +233,8 @@ public ParquetReaderOptions withMaxReadBlockSize(DataSize maxReadBlockSize) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withMaxReadBlockRowCount(int maxReadBlockRowCount) @@ -157,7 +248,8 @@ public ParquetReaderOptions withMaxReadBlockRowCount(int maxReadBlockRowCount) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withMaxMergeDistance(DataSize maxMergeDistance) @@ -171,7 +263,8 @@ public ParquetReaderOptions withMaxMergeDistance(DataSize maxMergeDistance) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withMaxBufferSize(DataSize maxBufferSize) @@ -185,7 +278,8 @@ public ParquetReaderOptions withMaxBufferSize(DataSize maxBufferSize) useColumnIndex, 
useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withUseColumnIndex(boolean useColumnIndex) @@ -199,7 +293,8 @@ public ParquetReaderOptions withUseColumnIndex(boolean useColumnIndex) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withBloomFilter(boolean useBloomFilter) @@ -213,7 +308,8 @@ public ParquetReaderOptions withBloomFilter(boolean useBloomFilter) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withSmallFileThreshold(DataSize smallFileThreshold) @@ -227,7 +323,8 @@ public ParquetReaderOptions withSmallFileThreshold(DataSize smallFileThreshold) useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); } public ParquetReaderOptions withVectorizedDecodingEnabled(boolean vectorizedDecodingEnabled) @@ -241,6 +338,22 @@ public ParquetReaderOptions withVectorizedDecodingEnabled(boolean vectorizedDeco useColumnIndex, useBloomFilter, smallFileThreshold, - vectorizedDecodingEnabled); + vectorizedDecodingEnabled, + parquetReaderEncryptionOptions); + } + + public ParquetReaderOptions withEncryptionOption(ParquetReaderEncryptionOptions encryptionOptions) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + vectorizedDecodingEnabled, + encryptionOptions); } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AADPrefixVerifier.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AADPrefixVerifier.java new file mode 100644 index 
000000000000..2cb3ae66b215 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AADPrefixVerifier.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +public interface AADPrefixVerifier +{ + /** + * Verifies identity (AAD Prefix) of individual file, or of file collection in a data set. + * Must be thread-safe. + * + * @param aadPrefix AAD Prefix + * @throws ParquetCryptoRuntimeException Throw exception if AAD prefix is wrong. + */ + void verify(byte[] aadPrefix) + throws ParquetCryptoRuntimeException; +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCipher.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCipher.java new file mode 100644 index 000000000000..8952a6fa02bc --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCipher.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import io.trino.parquet.crypto.ModuleCipherFactory.ModuleType; + +import javax.crypto.Cipher; +import javax.crypto.spec.SecretKeySpec; + +import java.security.SecureRandom; + +import static java.util.Objects.requireNonNull; + +public class AesCipher +{ + public static final int NONCE_LENGTH = 12; + public static final int GCM_TAG_LENGTH = 16; + protected static final int CTR_IV_LENGTH = 16; + protected static final int GCM_TAG_LENGTH_BITS = 8 * GCM_TAG_LENGTH; + protected static final int CHUNK_LENGTH = 4 * 1024; + protected static final int SIZE_LENGTH = ModuleCipherFactory.SIZE_LENGTH; + // NIST SP 800-38D section 8.3 specifies limit on AES GCM encryption operations with same key and random IV/nonce + protected static final long GCM_RANDOM_IV_SAME_KEY_MAX_OPS = 1L << 32; + // NIST SP 800-38A doesn't specify limit on AES CTR operations. + // However, Parquet uses a random IV (with 12-byte random nonce). To avoid repetition due to "birthday problem", + // setting a conservative limit equal to GCM's value for random IVs + protected static final long CTR_RANDOM_IV_SAME_KEY_MAX_OPS = GCM_RANDOM_IV_SAME_KEY_MAX_OPS; + static final int AAD_FILE_UNIQUE_LENGTH = 8; + protected final SecureRandom randomGenerator; + protected final byte[] localNonce; + protected SecretKeySpec aesKey; + protected Cipher cipher; + + AesCipher(AesMode mode, byte[] keyBytes) + { + requireNonNull(keyBytes, "key bytes cannot be null"); + boolean allZeroKey = true; + for (byte kb : keyBytes) { + if (kb != 0) { + allZeroKey = false; + break; + } + } + + if (allZeroKey) { + throw new IllegalArgumentException("All key bytes are zero"); + } + + aesKey = new SecretKeySpec(keyBytes, "AES"); + randomGenerator = new SecureRandom(); + localNonce = new byte[NONCE_LENGTH]; + } + + public static byte[] createModuleAAD( + byte[] fileAAD, ModuleType moduleType, int rowGroupOrdinal, int columnOrdinal, int pageOrdinal) + { + byte[] typeOrdinalBytes = new byte[1]; + 
typeOrdinalBytes[0] = moduleType.getValue(); + + if (ModuleType.Footer == moduleType) { + return concatByteArrays(fileAAD, typeOrdinalBytes); + } + + if (rowGroupOrdinal < 0) { + throw new IllegalArgumentException("Wrong row group ordinal: " + rowGroupOrdinal); + } + short shortRGOrdinal = (short) rowGroupOrdinal; + if (shortRGOrdinal != rowGroupOrdinal) { + throw new ParquetCryptoRuntimeException("Encrypted parquet files can't have " + "more than " + + Short.MAX_VALUE + " row groups: " + rowGroupOrdinal); + } + byte[] rowGroupOrdinalBytes = shortToBytesLE(shortRGOrdinal); + + if (columnOrdinal < 0) { + throw new IllegalArgumentException("Wrong column ordinal: " + columnOrdinal); + } + short shortColumOrdinal = (short) columnOrdinal; + if (shortColumOrdinal != columnOrdinal) { + throw new ParquetCryptoRuntimeException("Encrypted parquet files can't have " + "more than " + + Short.MAX_VALUE + " columns: " + columnOrdinal); + } + byte[] columnOrdinalBytes = shortToBytesLE(shortColumOrdinal); + + if (ModuleType.DataPage != moduleType && ModuleType.DataPageHeader != moduleType) { + return concatByteArrays(fileAAD, typeOrdinalBytes, rowGroupOrdinalBytes, columnOrdinalBytes); + } + + if (pageOrdinal < 0) { + throw new IllegalArgumentException("Wrong page ordinal: " + pageOrdinal); + } + short shortPageOrdinal = (short) pageOrdinal; + if (shortPageOrdinal != pageOrdinal) { + throw new ParquetCryptoRuntimeException("Encrypted parquet files can't have " + "more than " + + Short.MAX_VALUE + " pages per chunk: " + pageOrdinal); + } + byte[] pageOrdinalBytes = shortToBytesLE(shortPageOrdinal); + + return concatByteArrays(fileAAD, typeOrdinalBytes, rowGroupOrdinalBytes, columnOrdinalBytes, pageOrdinalBytes); + } + + public static byte[] createFooterAAD(byte[] aadPrefixBytes) + { + return createModuleAAD(aadPrefixBytes, ModuleType.Footer, -1, -1, -1); + } + + // Update last two bytes with new page ordinal (instead of creating new page AAD from scratch) + public static void 
quickUpdatePageAAD(byte[] pageAAD, int newPageOrdinal) + { + requireNonNull(pageAAD, "pageAAD cannot be null"); + if (newPageOrdinal < 0) { + throw new IllegalArgumentException("Wrong page ordinal: " + newPageOrdinal); + } + short shortPageOrdinal = (short) newPageOrdinal; + if (shortPageOrdinal != newPageOrdinal) { + throw new ParquetCryptoRuntimeException("Encrypted parquet files can't have " + "more than " + + Short.MAX_VALUE + " pages per chunk: " + newPageOrdinal); + } + + byte[] pageOrdinalBytes = shortToBytesLE(shortPageOrdinal); + System.arraycopy(pageOrdinalBytes, 0, pageAAD, pageAAD.length - 2, 2); + } + + static byte[] concatByteArrays(byte[]... arrays) + { + int totalLength = 0; + for (byte[] array : arrays) { + totalLength += array.length; + } + + byte[] output = new byte[totalLength]; + int offset = 0; + for (byte[] array : arrays) { + System.arraycopy(array, 0, output, offset, array.length); + offset += array.length; + } + + return output; + } + + private static byte[] shortToBytesLE(short input) + { + byte[] output = new byte[2]; + output[1] = (byte) (0xff & (input >> 8)); + output[0] = (byte) (0xff & input); + + return output; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrDecryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrDecryptor.java new file mode 100644 index 000000000000..84422d223499 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrDecryptor.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.Cipher; +import javax.crypto.spec.IvParameterSpec; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; + +public class AesCtrDecryptor + extends AesCipher implements BlockCipher.Decryptor +{ + private final byte[] ctrIV; + + AesCtrDecryptor(byte[] keyBytes) + { + super(AesMode.CTR, keyBytes); + + try { + cipher = Cipher.getInstance(AesMode.CTR.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to create CTR cipher", e); + } + ctrIV = new byte[CTR_IV_LENGTH]; + // Setting last bit of initial CTR counter to 1 + ctrIV[CTR_IV_LENGTH - 1] = (byte) 1; + } + + @Override + public byte[] decrypt(byte[] lengthAndCiphertext, byte[] aad) + { + int cipherTextOffset = SIZE_LENGTH; + int cipherTextLength = lengthAndCiphertext.length - SIZE_LENGTH; + + return decrypt(lengthAndCiphertext, cipherTextOffset, cipherTextLength, aad); + } + + public byte[] decrypt(byte[] ciphertext, int cipherTextOffset, int cipherTextLength, byte[] aad) + { + int plainTextLength = cipherTextLength - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoRuntimeException("Wrong input length " + plainTextLength); + } + + // Get the nonce from ciphertext + System.arraycopy(ciphertext, cipherTextOffset, ctrIV, 0, NONCE_LENGTH); + + byte[] plainText = new byte[plainTextLength]; + int inputLength = cipherTextLength - NONCE_LENGTH; + int inputOffset = cipherTextOffset + NONCE_LENGTH; + int outputOffset = 0; + try { + IvParameterSpec spec = new IvParameterSpec(ctrIV); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + + // Breaking decryption into multiple updates, to trigger h/w acceleration in Java 9+ + while (inputLength > CHUNK_LENGTH) 
{ + int written = cipher.update(ciphertext, inputOffset, CHUNK_LENGTH, plainText, outputOffset); + inputOffset += CHUNK_LENGTH; + outputOffset += written; + inputLength -= CHUNK_LENGTH; + } + + cipher.doFinal(ciphertext, inputOffset, inputLength, plainText, outputOffset); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to decrypt", e); + } + + return plainText; + } + + public ByteBuffer decrypt(ByteBuffer ciphertext, byte[] aad) + { + int cipherTextOffset = SIZE_LENGTH; + int cipherTextLength = ciphertext.limit() - ciphertext.position() - SIZE_LENGTH; + + int plainTextLength = cipherTextLength - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoRuntimeException("Wrong input length " + plainTextLength); + } + + // skip size + ciphertext.position(ciphertext.position() + cipherTextOffset); + // Get the nonce from ciphertext + ciphertext.get(ctrIV, 0, NONCE_LENGTH); + + // Reuse the input buffer as the output buffer + ByteBuffer plainText = ciphertext.slice(); + plainText.limit(plainTextLength); + int inputLength = cipherTextLength - NONCE_LENGTH; + int inputOffset = cipherTextOffset + NONCE_LENGTH; + try { + IvParameterSpec spec = new IvParameterSpec(ctrIV); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + + // Breaking decryption into multiple updates, to trigger h/w acceleration in Java 9+ + while (inputLength > CHUNK_LENGTH) { + ciphertext.position(inputOffset); + ciphertext.limit(inputOffset + CHUNK_LENGTH); + cipher.update(ciphertext, plainText); + inputOffset += CHUNK_LENGTH; + inputLength -= CHUNK_LENGTH; + } + ciphertext.position(inputOffset); + ciphertext.limit(inputOffset + inputLength); + cipher.doFinal(ciphertext, plainText); + plainText.flip(); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to decrypt", e); + } + + return plainText; + } + + @Override + public byte[] decrypt(InputStream from, byte[] aad) + throws IOException + { + byte[] 
lengthBuffer = new byte[SIZE_LENGTH]; + int gotBytes = 0; + + // Read the length of encrypted Thrift structure + while (gotBytes < SIZE_LENGTH) { + int n = from.read(lengthBuffer, gotBytes, SIZE_LENGTH - gotBytes); + if (n <= 0) { + throw new IOException("Tried to read int (4 bytes), but only got " + gotBytes + " bytes."); + } + gotBytes += n; + } + + final int ciphertextLength = ((lengthBuffer[3] & 0xff) << 24) + | ((lengthBuffer[2] & 0xff) << 16) + | ((lengthBuffer[1] & 0xff) << 8) + | ((lengthBuffer[0] & 0xff)); + + if (ciphertextLength < 1) { + throw new IOException("Wrong length of encrypted metadata: " + ciphertextLength); + } + + // Read the encrypted structure contents + byte[] ciphertextBuffer = new byte[ciphertextLength]; + gotBytes = 0; + while (gotBytes < ciphertextLength) { + int n = from.read(ciphertextBuffer, gotBytes, ciphertextLength - gotBytes); + if (n <= 0) { + throw new IOException( + "Tried to read " + ciphertextLength + " bytes, but only got " + gotBytes + " bytes."); + } + gotBytes += n; + } + + // Decrypt the structure contents + return decrypt(ciphertextBuffer, 0, ciphertextLength, aad); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrEncryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrEncryptor.java new file mode 100644 index 000000000000..a5942e344efd --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrEncryptor.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.Cipher; +import javax.crypto.spec.IvParameterSpec; + +import java.security.GeneralSecurityException; + +public class AesCtrEncryptor + extends AesCipher + implements BlockCipher.Encryptor +{ + private final byte[] ctrIV; + private long operationCounter; + + AesCtrEncryptor(byte[] keyBytes) + { + super(AesMode.CTR, keyBytes); + operationCounter = 0; + + try { + cipher = Cipher.getInstance(AesMode.CTR.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to create CTR cipher", e); + } + + ctrIV = new byte[CTR_IV_LENGTH]; + // Setting last bit of initial CTR counter to 1 + ctrIV[CTR_IV_LENGTH - 1] = (byte) 1; + } + + @Override + public byte[] encrypt(byte[] plainText, byte[] aad) + { + return encrypt(true, plainText, aad); + } + + public byte[] encrypt(boolean writeLength, byte[] plainText, byte[] aad) + { + randomGenerator.nextBytes(localNonce); + return encrypt(writeLength, plainText, localNonce, aad); + } + + public byte[] encrypt(boolean writeLength, byte[] plainText, byte[] nonce, byte[] aad) + { + if (operationCounter > CTR_RANDOM_IV_SAME_KEY_MAX_OPS) { + throw new ParquetCryptoRuntimeException( + "Exceeded limit of AES CTR encryption operations with same key and random IV"); + } + operationCounter++; + + if (nonce.length != NONCE_LENGTH) { + throw new ParquetCryptoRuntimeException("Wrong nonce length " + nonce.length); + } + int plainTextLength = plainText.length; + int cipherTextLength = NONCE_LENGTH + plainTextLength; + int lengthBufferLength = writeLength ? 
SIZE_LENGTH : 0; + byte[] cipherText = new byte[lengthBufferLength + cipherTextLength]; + int inputLength = plainTextLength; + int inputOffset = 0; + int outputOffset = lengthBufferLength + NONCE_LENGTH; + try { + System.arraycopy(nonce, 0, ctrIV, 0, NONCE_LENGTH); + IvParameterSpec spec = new IvParameterSpec(ctrIV); + cipher.init(Cipher.ENCRYPT_MODE, aesKey, spec); + + // Breaking encryption into multiple updates, to trigger h/w acceleration in Java 9+ + while (inputLength > CHUNK_LENGTH) { + int written = cipher.update(plainText, inputOffset, CHUNK_LENGTH, cipherText, outputOffset); + inputOffset += CHUNK_LENGTH; + outputOffset += written; + inputLength -= CHUNK_LENGTH; + } + + cipher.doFinal(plainText, inputOffset, inputLength, cipherText, outputOffset); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to encrypt", e); + } + + // Add ciphertext length + if (writeLength) { + System.arraycopy(BytesUtils.intToBytes(cipherTextLength), 0, cipherText, 0, lengthBufferLength); + } + // Add the nonce + System.arraycopy(nonce, 0, cipherText, lengthBufferLength, NONCE_LENGTH); + + return cipherText; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmDecryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmDecryptor.java new file mode 100644 index 000000000000..b7072216ad38 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmDecryptor.java @@ -0,0 +1,161 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.AEADBadTagException; +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; + +public class AesGcmDecryptor + extends AesCipher + implements BlockCipher.Decryptor +{ + AesGcmDecryptor(byte[] keyBytes) + { + super(AesMode.GCM, keyBytes); + + try { + cipher = Cipher.getInstance(AesMode.GCM.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to create GCM cipher", e); + } + } + + @Override + public byte[] decrypt(byte[] lengthAndCiphertext, byte[] aad) + { + int cipherTextOffset = SIZE_LENGTH; + int cipherTextLength = lengthAndCiphertext.length - SIZE_LENGTH; + + return decrypt(lengthAndCiphertext, cipherTextOffset, cipherTextLength, aad); + } + + public byte[] decrypt(byte[] ciphertext, int cipherTextOffset, int cipherTextLength, byte[] aad) + { + int plainTextLength = cipherTextLength - GCM_TAG_LENGTH - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoRuntimeException("Wrong input length " + plainTextLength); + } + + // Get the nonce from ciphertext + System.arraycopy(ciphertext, cipherTextOffset, localNonce, 0, NONCE_LENGTH); + + byte[] plainText = new byte[plainTextLength]; + int inputLength = cipherTextLength - NONCE_LENGTH; + int inputOffset = cipherTextOffset + NONCE_LENGTH; + int outputOffset = 0; + try { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, localNonce); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + if (null != aad) { + cipher.updateAAD(aad); + } + + cipher.doFinal(ciphertext, inputOffset, inputLength, plainText, outputOffset); + } + catch (AEADBadTagException e) { + throw new 
TagVerificationException("GCM tag check failed", e); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to decrypt", e); + } + + return plainText; + } + + public ByteBuffer decrypt(ByteBuffer ciphertext, byte[] aad) + { + int cipherTextOffset = SIZE_LENGTH; + int cipherTextLength = ciphertext.limit() - ciphertext.position() - SIZE_LENGTH; + int plainTextLength = cipherTextLength - GCM_TAG_LENGTH - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoRuntimeException("Wrong input length " + plainTextLength); + } + + ciphertext.position(ciphertext.position() + cipherTextOffset); + // Get the nonce from ciphertext + ciphertext.get(localNonce); + + // Reuse the input buffer as the output buffer + ByteBuffer plainText = ciphertext.slice(); + plainText.limit(plainTextLength); + try { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, localNonce); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + if (null != aad) { + cipher.updateAAD(aad); + } + + cipher.doFinal(ciphertext, plainText); + plainText.flip(); + } + catch (AEADBadTagException e) { + throw new TagVerificationException("GCM tag check failed", e); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to decrypt", e); + } + + return plainText; + } + + @Override + public byte[] decrypt(InputStream from, byte[] aad) + throws IOException + { + byte[] lengthBuffer = new byte[SIZE_LENGTH]; + int gotBytes = 0; + + // Read the length of encrypted Thrift structure + while (gotBytes < SIZE_LENGTH) { + int n = from.read(lengthBuffer, gotBytes, SIZE_LENGTH - gotBytes); + if (n <= 0) { + throw new IOException("Tried to read int (4 bytes), but only got " + gotBytes + " bytes."); + } + gotBytes += n; + } + + final int ciphertextLength = ((lengthBuffer[3] & 0xff) << 24) + | ((lengthBuffer[2] & 0xff) << 16) + | ((lengthBuffer[1] & 0xff) << 8) + | ((lengthBuffer[0] & 0xff)); + + if (ciphertextLength < 1) { + throw new 
IOException("Wrong length of encrypted metadata: " + ciphertextLength); + } + + byte[] ciphertextBuffer = new byte[ciphertextLength]; + gotBytes = 0; + // Read the encrypted structure contents + while (gotBytes < ciphertextLength) { + int n = from.read(ciphertextBuffer, gotBytes, ciphertextLength - gotBytes); + if (n <= 0) { + throw new IOException( + "Tried to read " + ciphertextLength + " bytes, but only got " + gotBytes + " bytes."); + } + gotBytes += n; + } + + // Decrypt the structure contents + return decrypt(ciphertextBuffer, 0, ciphertextLength, aad); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmEncryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmEncryptor.java new file mode 100644 index 000000000000..82b5b6d24fb9 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmEncryptor.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; + +import java.security.GeneralSecurityException; + +public class AesGcmEncryptor + extends AesCipher + implements BlockCipher.Encryptor +{ + private long operationCounter; + + AesGcmEncryptor(byte[] keyBytes) + { + super(AesMode.GCM, keyBytes); + operationCounter = 0; + + try { + cipher = Cipher.getInstance(AesMode.GCM.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to create GCM cipher", e); + } + } + + @Override + public byte[] encrypt(byte[] plainText, byte[] aad) + { + return encrypt(true, plainText, aad); + } + + public byte[] encrypt(boolean writeLength, byte[] plainText, byte[] aad) + { + randomGenerator.nextBytes(localNonce); + return encrypt(writeLength, plainText, localNonce, aad); + } + + public byte[] encrypt(boolean writeLength, byte[] plainText, byte[] nonce, byte[] aad) + { + if (operationCounter > GCM_RANDOM_IV_SAME_KEY_MAX_OPS) { + throw new ParquetCryptoRuntimeException( + "Exceeded limit of AES GCM encryption operations with same key and random IV"); + } + operationCounter++; + + if (nonce.length != NONCE_LENGTH) { + throw new ParquetCryptoRuntimeException("Wrong nonce length " + nonce.length); + } + int plainTextLength = plainText.length; + int cipherTextLength = NONCE_LENGTH + plainTextLength + GCM_TAG_LENGTH; + int lengthBufferLength = writeLength ? 
SIZE_LENGTH : 0; + byte[] cipherText = new byte[lengthBufferLength + cipherTextLength]; + int inputLength = plainTextLength; + int inputOffset = 0; + int outputOffset = lengthBufferLength + NONCE_LENGTH; + + try { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, nonce); + cipher.init(Cipher.ENCRYPT_MODE, aesKey, spec); + if (null != aad) { + cipher.updateAAD(aad); + } + + cipher.doFinal(plainText, inputOffset, inputLength, cipherText, outputOffset); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoRuntimeException("Failed to encrypt", e); + } + + // Add ciphertext length + if (writeLength) { + System.arraycopy(BytesUtils.intToBytes(cipherTextLength), 0, cipherText, 0, lengthBufferLength); + } + // Add the nonce + System.arraycopy(nonce, 0, cipherText, lengthBufferLength, NONCE_LENGTH); + + return cipherText; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesMode.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesMode.java new file mode 100644 index 000000000000..e8affac6c9f0 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesMode.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * AES cipher modes, each carrying the transformation string that is
 * passed to {@code javax.crypto.Cipher.getInstance}.
 */
public enum AesMode
{
    GCM("AES/GCM/NoPadding"),
    CTR("AES/CTR/NoPadding");

    private final String cipherName;

    AesMode(String transformation)
    {
        cipherName = transformation;
    }

    /**
     * @return the cipher transformation string of this mode
     */
    public String getCipherName()
    {
        return this.cipherName;
    }
}
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.hadoop.metadata.ColumnPath; + +/** + * This class is only required for setting explicit column decryption keys - + * to override key retriever (or to provide keys when key metadata and/or + * key retriever are not available) + */ +public class ColumnDecryptionProperties +{ + private final ColumnPath columnPath; + private final byte[] keyBytes; + + private ColumnDecryptionProperties(ColumnPath columnPath, byte[] keyBytes) + { + if (null == columnPath) { + throw new IllegalArgumentException("Null column path"); + } + if (null == keyBytes) { + throw new IllegalArgumentException("Null key for column " + columnPath); + } + if (!(keyBytes.length == 16 || keyBytes.length == 24 || keyBytes.length == 32)) { + throw new IllegalArgumentException("Wrong key length: " + keyBytes.length + " on column: " + columnPath); + } + + this.columnPath = columnPath; + this.keyBytes = keyBytes; + } + + /** + * Convenience builder for regular (not nested) columns. + * + * @param name Flat column name + * @return Builder + */ + public static Builder builder(String name) + { + return builder(ColumnPath.get(name)); + } + + public static Builder builder(ColumnPath path) + { + return new Builder(path); + } + + public static class Builder + { + private final ColumnPath columnPath; + private byte[] keyBytes; + + private Builder(ColumnPath path) + { + this.columnPath = path; + } + + /** + * Set an explicit column key. + * If applied on a file that contains key metadata for this column - + * the metadata will be ignored, the column will be decrypted with this key. + * However, if the column was encrypted with the footer key, it will also be decrypted with the + * footer key, and the column key passed in this method will be ignored. + * + * @param columnKey Key length must be either 16, 24 or 32 bytes. 
+ * @return Builder + */ + public Builder withKey(byte[] columnKey) + { + if (null != this.keyBytes) { + throw new IllegalStateException("Key already set on column: " + columnPath); + } + this.keyBytes = new byte[columnKey.length]; + System.arraycopy(columnKey, 0, this.keyBytes, 0, columnKey.length); + + return this; + } + + public ColumnDecryptionProperties build() + { + return new ColumnDecryptionProperties(columnPath, keyBytes); + } + } + + public ColumnPath getPath() + { + return columnPath; + } + + public byte[] getKeyBytes() + { + return keyBytes; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ColumnEncryptionProperties.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ColumnEncryptionProperties.java new file mode 100644 index 000000000000..c3a243ee8aed --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ColumnEncryptionProperties.java @@ -0,0 +1,202 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.nio.charset.StandardCharsets; + +public class ColumnEncryptionProperties +{ + private final boolean encrypted; + private final ColumnPath columnPath; + private final boolean encryptedWithFooterKey; + private final byte[] keyBytes; + private final byte[] keyMetaData; + + private ColumnEncryptionProperties(boolean encrypted, ColumnPath columnPath, byte[] keyBytes, byte[] keyMetaData) + { + if (null == columnPath) { + throw new IllegalArgumentException("Null column path"); + } + if (!encrypted) { + if (null != keyBytes) { + throw new IllegalArgumentException("Setting key on unencrypted column: " + columnPath); + } + if (null != keyMetaData) { + throw new IllegalArgumentException("Setting key metadata on unencrypted column: " + columnPath); + } + } + if ((null != keyBytes) && !(keyBytes.length == 16 || keyBytes.length == 24 || keyBytes.length == 32)) { + throw new IllegalArgumentException("Wrong key length: " + keyBytes.length + ". Column: " + columnPath); + } + encryptedWithFooterKey = (encrypted && (null == keyBytes)); + if (encryptedWithFooterKey && (null != keyMetaData)) { + throw new IllegalArgumentException( + "Setting key metadata on column encrypted with footer key: " + columnPath); + } + + this.encrypted = encrypted; + this.columnPath = columnPath; + this.keyBytes = keyBytes; + this.keyMetaData = keyMetaData; + } + + /** + * Convenience builder for regular (not nested) columns. + * To make sure column name is not misspelled or misplaced, + * file writer will verify that column is in file schema. + * + * @param name Flat column name + * @return Builder + */ + public static Builder builder(String name) + { + return builder(ColumnPath.get(name), true); + } + + /** + * Builder for encrypted columns. + * To make sure column path is not misspelled or misplaced, + * file writer will verify this column is in file schema. 
+ * + * @param path Column path + * @return Builder + */ + public static Builder builder(ColumnPath path) + { + return builder(path, true); + } + + /** + * Builder for encrypted columns. + * To make sure column path is not misspelled or misplaced, + * file writer will verify this column is in file schema. + * + * @param path Column path + * @param encrypt whether or not this column to be encrypted + * @return Builder + */ + public static Builder builder(ColumnPath path, boolean encrypt) + { + return new Builder(path, encrypt); + } + + public static class Builder + { + private final boolean encrypted; + private final ColumnPath columnPath; + + private byte[] keyBytes; + private byte[] keyMetaData; + + private Builder(ColumnPath path, boolean encrypted) + { + this.encrypted = encrypted; + this.columnPath = path; + } + + /** + * Set a column-specific key. + * If key is not set on an encrypted column, the column will + * be encrypted with the footer key. + * + * @param columnKey Key length must be either 16, 24 or 32 bytes. + * @return Builder + */ + public Builder withKey(byte[] columnKey) + { + if (null == columnKey) { + return this; + } + if (null != this.keyBytes) { + throw new IllegalStateException("Key already set on column: " + columnPath); + } + this.keyBytes = new byte[columnKey.length]; + System.arraycopy(columnKey, 0, this.keyBytes, 0, columnKey.length); + + return this; + } + + /** + * Set a key retrieval metadata. + * use either withKeyMetaData or withKeyID, not both. + * + * @param keyMetaData arbitrary byte array with encryption key metadata + * @return Builder + */ + public Builder withKeyMetaData(byte[] keyMetaData) + { + if (null == keyMetaData) { + return this; + } + if (null != this.keyMetaData) { + throw new IllegalStateException("Key metadata already set on column: " + columnPath); + } + this.keyMetaData = keyMetaData; + + return this; + } + + /** + * Set a key retrieval metadata (converted from String). 
+ * use either withKeyMetaData or withKeyID, not both. + * + * @param keyId will be converted to metadata (UTF-8 array). + * @return Builder + */ + public Builder withKeyID(String keyId) + { + if (null == keyId) { + return this; + } + byte[] metaData = keyId.getBytes(StandardCharsets.UTF_8); + + return withKeyMetaData(metaData); + } + + public ColumnEncryptionProperties build() + { + return new ColumnEncryptionProperties(encrypted, columnPath, keyBytes, keyMetaData); + } + } + + public ColumnPath getPath() + { + return columnPath; + } + + public boolean isEncrypted() + { + return encrypted; + } + + public byte[] getKeyBytes() + { + return keyBytes; + } + + public boolean isEncryptedWithFooterKey() + { + if (!encrypted) { + return false; + } + return encryptedWithFooterKey; + } + + public byte[] getKeyMetaData() + { + return keyMetaData; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/DecryptionKeyRetriever.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/DecryptionKeyRetriever.java new file mode 100644 index 000000000000..6967cb11c281 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/DecryptionKeyRetriever.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +/** + * Interface for classes retrieving encryption keys using the key metadata. 
+ * Implementations must be thread-safe, if same KeyRetriever object is passed to multiple file readers. + */ +public interface DecryptionKeyRetriever +{ + /** + * Returns encryption key using the key metadata. + * If your key retrieval code throws runtime exceptions related to access control (permission) problems + * (such as Hadoop AccessControlException), catch them and throw the KeyAccessDeniedException. + * + * @param keyMetaData arbitrary byte array with encryption key metadata + * @return encryption key. Key length can be either 16, 24 or 32 bytes. + * @throws KeyAccessDeniedException thrown upon access control problems (authentication or authorization) + * @throws ParquetCryptoRuntimeException thrown upon key retrieval problems unrelated to access control + */ + byte[] getKey(byte[] keyMetaData) + throws KeyAccessDeniedException, ParquetCryptoRuntimeException; +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionProperties.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionProperties.java new file mode 100644 index 000000000000..6c84b759881a --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionProperties.java @@ -0,0 +1,277 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.util.HashMap; +import java.util.Map; + +public class FileDecryptionProperties +{ + private static final boolean CHECK_SIGNATURE = true; + private static final boolean ALLOW_PLAINTEXT_FILES = false; + + private final byte[] footerKey; + private final DecryptionKeyRetriever keyRetriever; + private final byte[] aadPrefix; + private final AADPrefixVerifier aadPrefixVerifier; + private final Map columnPropertyMap; + private final boolean checkPlaintextFooterIntegrity; + private final boolean allowPlaintextFiles; + + private FileDecryptionProperties( + byte[] footerKey, + DecryptionKeyRetriever keyRetriever, + boolean checkPlaintextFooterIntegrity, + byte[] aadPrefix, + AADPrefixVerifier aadPrefixVerifier, + Map columnPropertyMap, + boolean allowPlaintextFiles) + { + if ((null == footerKey) && (null == keyRetriever) && (null == columnPropertyMap)) { + throw new IllegalArgumentException("No decryption properties are specified"); + } + if ((null != footerKey) && !(footerKey.length == 16 || footerKey.length == 24 || footerKey.length == 32)) { + throw new IllegalArgumentException("Wrong footer key length " + footerKey.length); + } + if ((null == footerKey) && checkPlaintextFooterIntegrity && (null == keyRetriever)) { + throw new IllegalArgumentException( + "Can't check footer integrity with null footer key and null key retriever"); + } + + this.footerKey = footerKey; + this.checkPlaintextFooterIntegrity = checkPlaintextFooterIntegrity; + this.keyRetriever = keyRetriever; + this.aadPrefix = aadPrefix; + this.columnPropertyMap = columnPropertyMap; + this.aadPrefixVerifier = aadPrefixVerifier; + this.allowPlaintextFiles = allowPlaintextFiles; + } + + public static Builder builder() + { + return new Builder(); + } + + public byte[] getFooterKey() + { + return footerKey; + } + + public byte[] getColumnKey(ColumnPath path) + { + if (null == columnPropertyMap) { + return 
null; + } + ColumnDecryptionProperties columnDecryptionProperties = columnPropertyMap.get(path); + if (null == columnDecryptionProperties) { + return null; + } + + return columnDecryptionProperties.getKeyBytes(); + } + + public DecryptionKeyRetriever getKeyRetriever() + { + return keyRetriever; + } + + public byte[] getAADPrefix() + { + return aadPrefix; + } + + public boolean checkFooterIntegrity() + { + return checkPlaintextFooterIntegrity; + } + + boolean plaintextFilesAllowed() + { + return allowPlaintextFiles; + } + + AADPrefixVerifier getAADPrefixVerifier() + { + return aadPrefixVerifier; + } + + public static class Builder + { + private byte[] footerKeyBytes; + private DecryptionKeyRetriever keyRetriever; + private byte[] aadPrefixBytes; + private AADPrefixVerifier aadPrefixVerifier; + private Map columnPropertyMap; + private boolean checkPlaintextFooterIntegrity; + private boolean plaintextFilesAllowed; + + private Builder() + { + this.checkPlaintextFooterIntegrity = CHECK_SIGNATURE; + this.plaintextFilesAllowed = ALLOW_PLAINTEXT_FILES; + } + + /** + * Set an explicit footer key. If applied on a file that contains footer key metadata - + * the metadata will be ignored, the footer will be decrypted/verified with this key. + * If explicit key is not set, footer key will be fetched from key retriever. + * + * @param footerKey Key length must be either 16, 24 or 32 bytes. + * @return Builder + */ + public Builder withFooterKey(byte[] footerKey) + { + if (null == footerKey) { + return this; + } + if (null != this.footerKeyBytes) { + throw new IllegalStateException("Footer key already set"); + } + this.footerKeyBytes = new byte[footerKey.length]; + System.arraycopy(footerKey, 0, this.footerKeyBytes, 0, footerKey.length); + + return this; + } + + /** + * Set explicit column keys (decryption properties). + * Its also possible to set a key retriever on this file decryption properties object. 
+ * Upon reading, availability of explicit keys is checked before invocation of the retriever callback. + * If an explicit key is available for a footer or a column, its key metadata will be ignored. + * + * @param columnProperties Explicit column decryption keys + * @return Builder + */ + public Builder withColumnKeys(Map columnProperties) + { + if (null == columnProperties) { + return this; + } + if (null != this.columnPropertyMap) { + throw new IllegalStateException("Column properties already set"); + } + // Copy the map to make column properties immutable + this.columnPropertyMap = new HashMap(columnProperties); + + return this; + } + + /** + * Set a key retriever callback. It is also possible to + * set explicit footer or column keys on this file property object. Upon file decryption, + * availability of explicit keys is checked before invocation of the retriever callback. + * If an explicit key is available for a footer or a column, its key metadata will + * be ignored. + * + * @param keyRetriever Key retriever object + * @return Builder + */ + public Builder withKeyRetriever(DecryptionKeyRetriever keyRetriever) + { + if (null == keyRetriever) { + return this; + } + if (null != this.keyRetriever) { + throw new IllegalStateException("Key retriever already set"); + } + this.keyRetriever = keyRetriever; + + return this; + } + + /** + * Skip integrity verification of plaintext footers. + * If not called, integrity of plaintext footers will be checked in runtime, and an exception will + * be thrown in the following situations: + * - footer signing key is not available (not passed, or not found by key retriever) + * - footer content doesn't match the signature + * + * @return Builder + */ + public Builder withoutFooterSignatureVerification() + { + this.checkPlaintextFooterIntegrity = false; + return this; + } + + /** + * Explicitly supply the file AAD prefix. + * A must when a prefix is used for file encryption, but not stored in file. 
+ * If AAD prefix is stored in file, it will be compared to the explicitly supplied value + * and an exception will be thrown if they differ. + * + * @param aadPrefixBytes AAD Prefix + * @return Builder + */ + public Builder withAADPrefix(byte[] aadPrefixBytes) + { + if (null == aadPrefixBytes) { + return this; + } + if (null != this.aadPrefixBytes) { + throw new IllegalStateException("AAD Prefix already set"); + } + this.aadPrefixBytes = aadPrefixBytes; + + return this; + } + + /** + * Set callback for verification of AAD Prefixes stored in file. + * + * @param aadPrefixVerifier AAD prefix verification object + * @return Builder + */ + public Builder withAADPrefixVerifier(AADPrefixVerifier aadPrefixVerifier) + { + if (null == aadPrefixVerifier) { + return this; + } + if (null != this.aadPrefixVerifier) { + throw new IllegalStateException("AAD Prefix verifier already set"); + } + this.aadPrefixVerifier = aadPrefixVerifier; + + return this; + } + + /** + * By default, reading plaintext (unencrypted) files is not allowed when using a decryptor + * - in order to detect files that were not encrypted by mistake. + * However, the default behavior can be overriden by calling this method. + * The caller should use then a different method to ensure encryption of files with sensitive data. 
+ * + * @return Builder + */ + public Builder withPlaintextFilesAllowed() + { + this.plaintextFilesAllowed = true; + return this; + } + + public FileDecryptionProperties build() + { + return new FileDecryptionProperties( + footerKeyBytes, + keyRetriever, + checkPlaintextFooterIntegrity, + aadPrefixBytes, + aadPrefixVerifier, + columnPropertyMap, + plaintextFilesAllowed); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileEncryptionProperties.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileEncryptionProperties.java new file mode 100644 index 000000000000..b590fc5db412 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileEncryptionProperties.java @@ -0,0 +1,331 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.parquet.crypto; + +import org.apache.parquet.format.EncryptionAlgorithm; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.util.HashMap; +import java.util.Map; + +import static io.trino.parquet.crypto.AesCipher.AAD_FILE_UNIQUE_LENGTH; + +public class FileEncryptionProperties +{ + private static final ParquetCipher ALGORITHM_DEFAULT = ParquetCipher.AES_GCM_V1; + private static final boolean ENCRYPTED_FOOTER_DEFAULT = true; + private static final boolean COMPLETE_COLUMN_ENCRYPTION_DEFAULT = false; + + private final EncryptionAlgorithm algorithm; + private final boolean encryptedFooter; + private final byte[] footerKey; + private final byte[] footerKeyMetadata; + private final byte[] fileAAD; + private final Map columnPropertyMap; + private final boolean completeColumnEncryption; + + private FileEncryptionProperties( + ParquetCipher cipher, + byte[] footerKey, + byte[] footerKeyMetadata, + boolean encryptedFooter, + byte[] aadPrefix, + boolean storeAadPrefixInFile, + Map columnPropertyMap, + boolean completeColumnEncryption) + { + if (null == footerKey) { + throw new IllegalArgumentException("Footer key is null"); + } + if (!(footerKey.length == 16 || footerKey.length == 24 || footerKey.length == 32)) { + throw new IllegalArgumentException("Wrong footer key length " + footerKey.length); + } + if (null != columnPropertyMap) { + if (columnPropertyMap.isEmpty()) { + throw new IllegalArgumentException("No encrypted columns"); + } + } + else { + if (completeColumnEncryption) { + throw new IllegalArgumentException("Encrypted columns are not specified, cannot complete"); + } + } + + SecureRandom random = new SecureRandom(); + byte[] aadFileUnique = new byte[AAD_FILE_UNIQUE_LENGTH]; + random.nextBytes(aadFileUnique); + + boolean supplyAadPrefix = false; + if (null == aadPrefix) { + this.fileAAD = aadFileUnique; + } + else { + this.fileAAD = 
AesCipher.concatByteArrays(aadPrefix, aadFileUnique); + if (!storeAadPrefixInFile) { + supplyAadPrefix = true; + } + } + + this.algorithm = cipher.getEncryptionAlgorithm(); + + if (algorithm.isSetAES_GCM_V1()) { + algorithm.getAES_GCM_V1().setAad_file_unique(aadFileUnique); + algorithm.getAES_GCM_V1().setSupply_aad_prefix(supplyAadPrefix); + if (null != aadPrefix && storeAadPrefixInFile) { + algorithm.getAES_GCM_V1().setAad_prefix(aadPrefix); + } + } + else { + algorithm.getAES_GCM_CTR_V1().setAad_file_unique(aadFileUnique); + algorithm.getAES_GCM_CTR_V1().setSupply_aad_prefix(supplyAadPrefix); + if (null != aadPrefix && storeAadPrefixInFile) { + algorithm.getAES_GCM_CTR_V1().setAad_prefix(aadPrefix); + } + } + + this.footerKey = footerKey; + this.footerKeyMetadata = footerKeyMetadata; + this.encryptedFooter = encryptedFooter; + this.columnPropertyMap = columnPropertyMap; + this.completeColumnEncryption = completeColumnEncryption; + } + + /** + * @param footerKey Encryption key for file footer and some (or all) columns. + * Key length must be either 16, 24 or 32 bytes. + * If null, footer won't be encrypted. At least one column must be encrypted then. 
+ * @return Builder + */ + public static Builder builder(byte[] footerKey) + { + return new Builder(footerKey); + } + + public EncryptionAlgorithm getAlgorithm() + { + return algorithm; + } + + public byte[] getFooterKey() + { + return footerKey; + } + + public byte[] getFooterKeyMetadata() + { + return footerKeyMetadata; + } + + public Map getEncryptedColumns() + { + return columnPropertyMap; + } + + public ColumnEncryptionProperties getColumnProperties(ColumnPath columnPath) + { + if (null == columnPropertyMap) { + // encrypted, with footer key + return ColumnEncryptionProperties.builder(columnPath, true).build(); + } + else { + ColumnEncryptionProperties columnProperties = columnPropertyMap.get(columnPath); + if (null != columnProperties) { + return columnProperties; + } + else { // not set explicitly + if (completeColumnEncryption) { + // encrypted with footer key + return ColumnEncryptionProperties.builder(columnPath, true).build(); + } + else { + // plaintext column + return ColumnEncryptionProperties.builder(columnPath, false).build(); + } + } + } + } + + public byte[] getFileAAD() + { + return fileAAD; + } + + public boolean encryptedFooter() + { + return encryptedFooter; + } + + public static class Builder + { + private byte[] footerKeyBytes; + private boolean encryptedFooter; + private ParquetCipher parquetCipher; + private byte[] footerKeyMetadata; + private byte[] aadPrefix; + private Map columnPropertyMap; + private boolean storeAadPrefixInFile; + private boolean completeColumnEncryption; + + private Builder(byte[] footerKey) + { + this.parquetCipher = ALGORITHM_DEFAULT; + this.encryptedFooter = ENCRYPTED_FOOTER_DEFAULT; + this.completeColumnEncryption = COMPLETE_COLUMN_ENCRYPTION_DEFAULT; + this.footerKeyBytes = new byte[footerKey.length]; + System.arraycopy(footerKey, 0, this.footerKeyBytes, 0, footerKey.length); + } + + /** + * Create files with plaintext footer. + * If not called, the files will be created with encrypted footer (default). 
+ * + * @return Builder + */ + public Builder withPlaintextFooter() + { + this.encryptedFooter = false; + return this; + } + + /** + * Set encryption algorithm. + * If not called, files will be encrypted with AES_GCM_V1 (default). + * + * @param parquetCipher Encryption algorithm + * @return Builder + */ + public Builder withAlgorithm(ParquetCipher parquetCipher) + { + this.parquetCipher = parquetCipher; + return this; + } + + /** + * Set a key retrieval metadata (converted from String). + * Use either withFooterKeyMetaData or withFooterKeyID, not both. + * + * @param keyID will be converted to metadata (UTF-8 array). + * @return Builder + */ + public Builder withFooterKeyID(String keyID) + { + if (null == keyID) { + return this; + } + + return withFooterKeyMetadata(keyID.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Set a key retrieval metadata. + * Use either withFooterKeyMetaData or withFooterKeyID, not both. + * + * @param footerKeyMetadata Key metadata + * @return Builder + */ + public Builder withFooterKeyMetadata(byte[] footerKeyMetadata) + { + if (null == footerKeyMetadata) { + return this; + } + if (null != this.footerKeyMetadata) { + throw new IllegalStateException("Footer key metadata already set"); + } + this.footerKeyMetadata = footerKeyMetadata; + + return this; + } + + /** + * Set the file AAD Prefix. + * + * @param aadPrefixBytes AAD Prefix + * @return Builder + */ + public Builder withAADPrefix(byte[] aadPrefixBytes) + { + if (null == aadPrefixBytes) { + return this; + } + if (null != this.aadPrefix) { + throw new IllegalStateException("AAD Prefix already set"); + } + this.aadPrefix = aadPrefixBytes; + this.storeAadPrefixInFile = true; + + return this; + } + + /** + * Skip storing AAD Prefix in file metadata. + * If not called, and if AAD Prefix is set, it will be stored. 
+ * + * @return Builder + */ + public Builder withoutAADPrefixStorage() + { + if (null == this.aadPrefix) { + throw new IllegalStateException("AAD Prefix not yet set"); + } + this.storeAadPrefixInFile = false; + + return this; + } + + /** + * Set the list of encrypted columns and their properties (keys etc). + * If not called, all columns will be encrypted with the footer key. + * If called, the file columns not in the list will be left unencrypted. + * + * @param encryptedColumns Columns to be encrypted + * @return Builder + */ + public Builder withEncryptedColumns(Map encryptedColumns) + { + if (null == encryptedColumns) { + return this; + } + if (null != this.columnPropertyMap) { + throw new IllegalStateException("Column properties already set"); + } + // Copy the map to make column properties immutable + this.columnPropertyMap = new HashMap(encryptedColumns); + + return this; + } + + public Builder withCompleteColumnEncryption() + { + this.completeColumnEncryption = true; + + return this; + } + + public FileEncryptionProperties build() + { + return new FileEncryptionProperties( + parquetCipher, + footerKeyBytes, + footerKeyMetadata, + encryptedFooter, + aadPrefix, + storeAadPrefixInFile, + columnPropertyMap, + completeColumnEncryption); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/HiddenColumnChunkMetaData.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/HiddenColumnChunkMetaData.java new file mode 100644 index 000000000000..40406c6efbbe --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/HiddenColumnChunkMetaData.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import io.trino.parquet.metadata.ColumnChunkMetadata; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +public class HiddenColumnChunkMetaData + extends ColumnChunkMetadata +{ + private final ColumnPath path; + private final String filePath; + + public HiddenColumnChunkMetaData(ColumnPath path, String filePath) + { + super(null, null); + this.path = path; + this.filePath = filePath; + } + + public static boolean isHiddenColumn(ColumnChunkMetadata column) + { + return column instanceof HiddenColumnChunkMetaData; + } + + @Override + public long getFirstDataPageOffset() + { + throw new HiddenColumnException(this.path.toArray(), this.filePath); + } + + @Override + public long getDictionaryPageOffset() + { + throw new HiddenColumnException(this.path.toArray(), this.filePath); + } + + @Override + public long getValueCount() + { + throw new HiddenColumnException(this.path.toArray(), this.filePath); + } + + @Override + public long getTotalUncompressedSize() + { + throw new HiddenColumnException(this.path.toArray(), this.filePath); + } + + @Override + public long getTotalSize() + { + throw new HiddenColumnException(this.path.toArray(), this.filePath); + } + + @Override + public Statistics getStatistics() + { + throw new HiddenColumnException(this.path.toArray(), this.filePath); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/HiddenColumnException.java 
b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/HiddenColumnException.java new file mode 100644 index 000000000000..806eff57eb01 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/HiddenColumnException.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.ParquetRuntimeException; + +import java.util.Arrays; + +public class HiddenColumnException + extends ParquetRuntimeException +{ + private static final long serialVersionUID = 1L; + + public HiddenColumnException(String[] columnPath, String filePath) + { + super("User does not have access to the encryption key for encrypted column = " + Arrays.toString(columnPath) + " for file: " + filePath); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/InternalColumnDecryptionSetup.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/InternalColumnDecryptionSetup.java new file mode 100644 index 000000000000..a818fbfbd92f --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/InternalColumnDecryptionSetup.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +public class InternalColumnDecryptionSetup +{ + private final ColumnPath columnPath; + private final boolean isEncrypted; + private final boolean isEncryptedWithFooterKey; + private final BlockCipher.Decryptor dataDecryptor; + private final BlockCipher.Decryptor metaDataDecryptor; + private final int columnOrdinal; + private final byte[] keyMetadata; + + InternalColumnDecryptionSetup( + ColumnPath path, + boolean encrypted, + boolean isEncryptedWithFooterKey, + BlockCipher.Decryptor dataDecryptor, + BlockCipher.Decryptor metaDataDecryptor, + int columnOrdinal, + byte[] keyMetadata) + { + this.columnPath = path; + this.isEncrypted = encrypted; + this.isEncryptedWithFooterKey = isEncryptedWithFooterKey; + this.dataDecryptor = dataDecryptor; + this.metaDataDecryptor = metaDataDecryptor; + this.columnOrdinal = columnOrdinal; + this.keyMetadata = keyMetadata; + } + + public boolean isEncrypted() + { + return isEncrypted; + } + + public BlockCipher.Decryptor getDataDecryptor() + { + return dataDecryptor; + } + + public BlockCipher.Decryptor getMetaDataDecryptor() + { + return metaDataDecryptor; + } + + boolean isEncryptedWithFooterKey() + { + return isEncryptedWithFooterKey; + } + + ColumnPath getPath() + { + return columnPath; + } + + public int getOrdinal() + { + return columnOrdinal; + } + + byte[] getKeyMetadata() + { + return keyMetadata; + } +} diff --git 
a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/InternalFileDecryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/InternalFileDecryptor.java new file mode 100644 index 000000000000..da31223e52fd --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/InternalFileDecryptor.java @@ -0,0 +1,344 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import io.airlift.log.Logger; +import org.apache.parquet.format.BlockCipher; +import org.apache.parquet.format.EncryptionAlgorithm; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.util.Arrays; +import java.util.HashMap; + +public class InternalFileDecryptor +{ + private static final Logger LOG = Logger.get(InternalFileDecryptor.class); + private final FileDecryptionProperties fileDecryptionProperties; + private final DecryptionKeyRetriever keyRetriever; + private final boolean checkPlaintextFooterIntegrity; + private final byte[] aadPrefixInProperties; + private final AADPrefixVerifier aadPrefixVerifier; + + private byte[] footerKey; + private HashMap columnMap; + private EncryptionAlgorithm algorithm; + private byte[] fileAAD; + private boolean encryptedFooter; + private byte[] footerKeyMetaData; + private boolean fileCryptoMetaDataProcessed; + private BlockCipher.Decryptor aesGcmDecryptorWithFooterKey; + private BlockCipher.Decryptor aesCtrDecryptorWithFooterKey; + private boolean plaintextFile; + + public 
InternalFileDecryptor(FileDecryptionProperties fileDecryptionProperties) + { + this.fileDecryptionProperties = fileDecryptionProperties; + checkPlaintextFooterIntegrity = fileDecryptionProperties.checkFooterIntegrity(); + footerKey = fileDecryptionProperties.getFooterKey(); + keyRetriever = fileDecryptionProperties.getKeyRetriever(); + aadPrefixInProperties = fileDecryptionProperties.getAADPrefix(); + columnMap = new HashMap(); + this.aadPrefixVerifier = fileDecryptionProperties.getAADPrefixVerifier(); + this.plaintextFile = false; + } + + private BlockCipher.Decryptor getThriftModuleDecryptor(byte[] columnKey) + { + if (null == columnKey) { // Decryptor with footer key + if (null == aesGcmDecryptorWithFooterKey) { + aesGcmDecryptorWithFooterKey = ModuleCipherFactory.getDecryptor(AesMode.GCM, footerKey); + } + return aesGcmDecryptorWithFooterKey; + } + else { // Decryptor with column key + return ModuleCipherFactory.getDecryptor(AesMode.GCM, columnKey); + } + } + + private BlockCipher.Decryptor getDataModuleDecryptor(byte[] columnKey) + { + if (algorithm.isSetAES_GCM_V1()) { + return getThriftModuleDecryptor(columnKey); + } + + // AES_GCM_CTR_V1 + if (null == columnKey) { // Decryptor with footer key + if (null == aesCtrDecryptorWithFooterKey) { + aesCtrDecryptorWithFooterKey = ModuleCipherFactory.getDecryptor(AesMode.CTR, footerKey); + } + return aesCtrDecryptorWithFooterKey; + } + else { // Decryptor with column key + return ModuleCipherFactory.getDecryptor(AesMode.CTR, columnKey); + } + } + + public InternalColumnDecryptionSetup getColumnSetup(ColumnPath path) + { + if (!fileCryptoMetaDataProcessed) { + throw new ParquetCryptoRuntimeException("Haven't parsed the file crypto metadata yet"); + } + InternalColumnDecryptionSetup columnDecryptionSetup = columnMap.get(path); + if (null == columnDecryptionSetup) { + throw new ParquetCryptoRuntimeException("Failed to find decryption setup for column " + path); + } + + return columnDecryptionSetup; + } + + public 
BlockCipher.Decryptor fetchFooterDecryptor() + { + if (!fileCryptoMetaDataProcessed) { + throw new ParquetCryptoRuntimeException("Haven't parsed the file crypto metadata yet"); + } + + return getThriftModuleDecryptor(null); + } + + public void setFileCryptoMetaData( + EncryptionAlgorithm algorithm, boolean encryptedFooter, byte[] footerKeyMetaData) + { + // first use of the decryptor + if (!fileCryptoMetaDataProcessed) { + fileCryptoMetaDataProcessed = true; + this.encryptedFooter = encryptedFooter; + this.algorithm = algorithm; + this.footerKeyMetaData = footerKeyMetaData; + + byte[] aadFileUnique; + boolean mustSupplyAadPrefix; + boolean fileHasAadPrefix = false; + byte[] aadPrefixInFile = null; + + // Process encryption algorithm metadata + if (algorithm.isSetAES_GCM_V1()) { + if (algorithm.getAES_GCM_V1().isSetAad_prefix()) { + fileHasAadPrefix = true; + aadPrefixInFile = algorithm.getAES_GCM_V1().getAad_prefix(); + } + mustSupplyAadPrefix = algorithm.getAES_GCM_V1().isSupply_aad_prefix(); + aadFileUnique = algorithm.getAES_GCM_V1().getAad_file_unique(); + } + else if (algorithm.isSetAES_GCM_CTR_V1()) { + if (algorithm.getAES_GCM_CTR_V1().isSetAad_prefix()) { + fileHasAadPrefix = true; + aadPrefixInFile = algorithm.getAES_GCM_CTR_V1().getAad_prefix(); + } + mustSupplyAadPrefix = algorithm.getAES_GCM_CTR_V1().isSupply_aad_prefix(); + aadFileUnique = algorithm.getAES_GCM_CTR_V1().getAad_file_unique(); + } + else { + throw new ParquetCryptoRuntimeException("Unsupported algorithm: " + algorithm); + } + + // Handle AAD prefix + byte[] aadPrefix = aadPrefixInProperties; + if (mustSupplyAadPrefix && (null == aadPrefixInProperties)) { + throw new ParquetCryptoRuntimeException("AAD prefix used for file encryption, " + + "but not stored in file and not supplied in decryption properties"); + } + + if (fileHasAadPrefix) { + if (null != aadPrefixInProperties) { + if (!Arrays.equals(aadPrefixInProperties, aadPrefixInFile)) { + throw new ParquetCryptoRuntimeException( + "AAD 
Prefix in file and in decryption properties is not the same"); + } + } + if (null != aadPrefixVerifier) { + aadPrefixVerifier.verify(aadPrefixInFile); + } + aadPrefix = aadPrefixInFile; + } + else { + if (!mustSupplyAadPrefix && (null != aadPrefixInProperties)) { + throw new ParquetCryptoRuntimeException( + "AAD Prefix set in decryption properties, but was not used for file encryption"); + } + if (null != aadPrefixVerifier) { + throw new ParquetCryptoRuntimeException( + "AAD Prefix Verifier is set, but AAD Prefix not found in file"); + } + } + + if (null == aadPrefix) { + this.fileAAD = aadFileUnique; + } + else { + this.fileAAD = AesCipher.concatByteArrays(aadPrefix, aadFileUnique); + } + + // Get footer key + if (null == footerKey) { // ignore footer key metadata if footer key is explicitly set via API + if (encryptedFooter || checkPlaintextFooterIntegrity) { + if (null == footerKeyMetaData) { + throw new ParquetCryptoRuntimeException("No footer key or key metadata"); + } + if (null == keyRetriever) { + throw new ParquetCryptoRuntimeException("No footer key or key retriever"); + } + + try { + footerKey = keyRetriever.getKey(footerKeyMetaData); + } + catch (KeyAccessDeniedException e) { + throw new KeyAccessDeniedException("Footer key: access denied", e); + } + + if (null == footerKey) { + throw new ParquetCryptoRuntimeException("Footer key unavailable"); + } + } + } + } + else { + // re-use of the decryptor + // check the crypto metadata. + if (!this.algorithm.equals(algorithm)) { + throw new ParquetCryptoRuntimeException("Decryptor re-use: Different algorithm"); + } + if (encryptedFooter != this.encryptedFooter) { + throw new ParquetCryptoRuntimeException("Decryptor re-use: Different footer encryption"); + } + if (!Arrays.equals(this.footerKeyMetaData, footerKeyMetaData)) { + throw new ParquetCryptoRuntimeException("Decryptor re-use: Different footer key metadata"); + } + } + + if (LOG.isDebugEnabled()) { + LOG.debug("File Decryptor. Algo: {}. 
Encrypted footer: {}", algorithm, encryptedFooter); + } + } + + public InternalColumnDecryptionSetup setColumnCryptoMetadata( + ColumnPath path, boolean encrypted, boolean encryptedWithFooterKey, byte[] keyMetadata, int columnOrdinal) + { + if (!fileCryptoMetaDataProcessed) { + throw new ParquetCryptoRuntimeException("Haven't parsed the file crypto metadata yet"); + } + + InternalColumnDecryptionSetup columnDecryptionSetup = columnMap.get(path); + if (null != columnDecryptionSetup) { + if (columnDecryptionSetup.isEncrypted() != encrypted) { + throw new ParquetCryptoRuntimeException("Re-use: wrong encrypted flag. Column: " + path); + } + if (encrypted) { + if (encryptedWithFooterKey != columnDecryptionSetup.isEncryptedWithFooterKey()) { + throw new ParquetCryptoRuntimeException( + "Re-use: wrong encryption key (column vs footer). Column: " + path); + } + if (!encryptedWithFooterKey && !Arrays.equals(columnDecryptionSetup.getKeyMetadata(), keyMetadata)) { + throw new ParquetCryptoRuntimeException("Decryptor re-use: Different footer key metadata "); + } + } + return columnDecryptionSetup; + } + + if (!encrypted) { + columnDecryptionSetup = + new InternalColumnDecryptionSetup(path, false, false, null, null, columnOrdinal, null); + } + else { + if (encryptedWithFooterKey) { + if (null == footerKey) { + throw new ParquetCryptoRuntimeException("Column " + path + " is encrypted with NULL footer key"); + } + columnDecryptionSetup = new InternalColumnDecryptionSetup( + path, + true, + true, + getDataModuleDecryptor(null), + getThriftModuleDecryptor(null), + columnOrdinal, + null); + + if (LOG.isDebugEnabled()) { + LOG.debug("Column decryption (footer key): {}", path); + } + } + else { // Column is encrypted with column-specific key + byte[] columnKeyBytes = fileDecryptionProperties.getColumnKey(path); + if ((null == columnKeyBytes) && (null != keyMetadata) && (null != keyRetriever)) { + // No explicit column key given via API. Retrieve via key metadata. 
+ try { + columnKeyBytes = keyRetriever.getKey(keyMetadata); + } + catch (KeyAccessDeniedException e) { + throw new KeyAccessDeniedException("Column " + path + ": key access denied", e); + } + } + if (null == columnKeyBytes) { + throw new ParquetCryptoRuntimeException("Column " + path + "is encrypted with NULL column key"); + } + columnDecryptionSetup = new InternalColumnDecryptionSetup( + path, + true, + false, + getDataModuleDecryptor(columnKeyBytes), + getThriftModuleDecryptor(columnKeyBytes), + columnOrdinal, + keyMetadata); + + if (LOG.isDebugEnabled()) { + LOG.debug("Column decryption (column key): {}", path); + } + } + } + columnMap.put(path, columnDecryptionSetup); + + return columnDecryptionSetup; + } + + public byte[] getFileAAD() + { + return this.fileAAD; + } + + public AesGcmEncryptor createSignedFooterEncryptor() + { + if (!fileCryptoMetaDataProcessed) { + throw new ParquetCryptoRuntimeException("Haven't parsed the file crypto metadata yet"); + } + if (encryptedFooter) { + throw new ParquetCryptoRuntimeException("Requesting signed footer encryptor in file with encrypted footer"); + } + + return (AesGcmEncryptor) ModuleCipherFactory.getEncryptor(AesMode.GCM, footerKey); + } + + public boolean checkFooterIntegrity() + { + return checkPlaintextFooterIntegrity; + } + + public boolean plaintextFilesAllowed() + { + return fileDecryptionProperties.plaintextFilesAllowed(); + } + + public void setPlaintextFile() + { + plaintextFile = true; + } + + public boolean plaintextFile() + { + return plaintextFile; + } + + public FileDecryptionProperties getDecryptionProperties() + { + return fileDecryptionProperties; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/KeyAccessDeniedException.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/KeyAccessDeniedException.java new file mode 100644 index 000000000000..6b613f41ec7a --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/KeyAccessDeniedException.java @@ 
-0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +public class KeyAccessDeniedException + extends ParquetCryptoRuntimeException +{ + private static final long serialVersionUID = 1L; + + public KeyAccessDeniedException() {} + + public KeyAccessDeniedException(String message, Throwable cause) + { + super(message, cause); + } + + public KeyAccessDeniedException(String message) + { + super(message); + } + + public KeyAccessDeniedException(Throwable cause) + { + super(cause); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ModuleCipherFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ModuleCipherFactory.java new file mode 100644 index 000000000000..60e6fa7df6d1 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ModuleCipherFactory.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher; + +public class ModuleCipherFactory +{ + public static final int SIZE_LENGTH = 4; + + private ModuleCipherFactory() + { + } + + public static BlockCipher.Encryptor getEncryptor(AesMode mode, byte[] keyBytes) + { + switch (mode) { + case GCM: + return new AesGcmEncryptor(keyBytes); + case CTR: + return new AesCtrEncryptor(keyBytes); + default: + throw new IllegalArgumentException("AesMode not supported in ModuleCipherFactory: " + mode); + } + } + + public static BlockCipher.Decryptor getDecryptor(AesMode mode, byte[] keyBytes) + { + switch (mode) { + case GCM: + return new AesGcmDecryptor(keyBytes); + case CTR: + return new AesCtrDecryptor(keyBytes); + default: + throw new IllegalArgumentException("AesMode not supported in ModuleCipherFactory: " + mode); + } + } + + // Parquet Module types + public enum ModuleType + { + Footer((byte) 0), + ColumnMetaData((byte) 1), + DataPage((byte) 2), + DictionaryPage((byte) 3), + DataPageHeader((byte) 4), + DictionaryPageHeader((byte) 5), + ColumnIndex((byte) 6), + OffsetIndex((byte) 7), + BloomFilterHeader((byte) 8), + BloomFilterBitset((byte) 9); + + private final byte value; + + ModuleType(byte value) + { + this.value = value; + } + + public byte getValue() + { + return value; + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCipher.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCipher.java new file mode 100644 index 000000000000..a64fba00e2a1 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCipher.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.AesGcmCtrV1; +import org.apache.parquet.format.AesGcmV1; +import org.apache.parquet.format.EncryptionAlgorithm; + +public enum ParquetCipher { + AES_GCM_V1 { + @Override + public EncryptionAlgorithm getEncryptionAlgorithm() + { + return EncryptionAlgorithm.AES_GCM_V1(new AesGcmV1()); + } + }, + AES_GCM_CTR_V1 { + @Override + public EncryptionAlgorithm getEncryptionAlgorithm() + { + return EncryptionAlgorithm.AES_GCM_CTR_V1(new AesGcmCtrV1()); + } + }; + + public abstract EncryptionAlgorithm getEncryptionAlgorithm(); +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCryptoRuntimeException.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCryptoRuntimeException.java new file mode 100644 index 000000000000..10d1c4238096 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCryptoRuntimeException.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.ParquetRuntimeException; + +/** + * Thrown upon encryption or decryption operation problem + */ +public class ParquetCryptoRuntimeException + extends ParquetRuntimeException +{ + private static final long serialVersionUID = 1L; + + public ParquetCryptoRuntimeException() {} + + public ParquetCryptoRuntimeException(String message, Throwable cause) + { + super(message, cause); + } + + public ParquetCryptoRuntimeException(String message) + { + super(message); + } + + public ParquetCryptoRuntimeException(Throwable cause) + { + super(cause); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TagVerificationException.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TagVerificationException.java new file mode 100644 index 000000000000..d76986d43326 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TagVerificationException.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +public class TagVerificationException + extends ParquetCryptoRuntimeException +{ + private static final long serialVersionUID = 1L; + + public TagVerificationException() {} + + public TagVerificationException(String message, Throwable cause) + { + super(message, cause); + } + + public TagVerificationException(String message) + { + super(message); + } + + public TagVerificationException(Throwable cause) + { + super(cause); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoCryptoConfigurationUtil.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoCryptoConfigurationUtil.java new file mode 100644 index 000000000000..f6ad3f3e0878 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoCryptoConfigurationUtil.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import io.airlift.log.Logger; + +public class TrinoCryptoConfigurationUtil +{ + public static final Logger LOG = Logger.get(TrinoCryptoConfigurationUtil.class); + + private TrinoCryptoConfigurationUtil() + { + } + + public static Class> getClassFromConfig(String className, Class> assignableFrom) + { + try { + final Class> foundClass = Class.forName(className); + if (!assignableFrom.isAssignableFrom(foundClass)) { + LOG.warn("class " + className + " is not a subclass of " + assignableFrom.getCanonicalName()); + return null; + } + return foundClass; + } + catch (ClassNotFoundException e) { + LOG.warn("could not instantiate class " + className, e); + return null; + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoDecryptionPropertiesFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoDecryptionPropertiesFactory.java new file mode 100644 index 000000000000..1e00c4879a98 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoDecryptionPropertiesFactory.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; + +public interface TrinoDecryptionPropertiesFactory +{ + // TODO(wyu): maybe create a dedicate config class in org.apache.parquet and convert ParquetReaderOptions to this class? 
+ /** + * Produces the {@link FileDecryptionProperties} for the given reader options, file location and file system. + * NOTE(review): presumably the result applies to the Parquet file at {@code filePath} - confirm against implementations. + * + * @param parquetReaderOptions reader options supplied by the caller + * @param filePath file location supplied by the caller + * @param trinoFileSystem file system supplied by the caller + * @return the decryption properties + * @throws ParquetCryptoRuntimeException failure type declared for implementations + */ + FileDecryptionProperties getFileDecryptionProperties(io.trino.parquet.ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + throws ParquetCryptoRuntimeException; +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoVersionedLocalWrap.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoVersionedLocalWrap.java new file mode 100644 index 000000000000..1eb244b4af76 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/TrinoVersionedLocalWrap.java @@ -0,0 +1,193 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.airlift.json.ObjectMapperProvider; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.keytools.TrinoKmsClient; + +import java.io.IOException; +import java.io.StringReader; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +public abstract class TrinoVersionedLocalWrap + implements TrinoKmsClient +{ + protected String kmsInstanceID; + protected String kmsInstanceURL; + protected String kmsToken; + protected ParquetReaderOptions trinoParquetCryptoConfig; + protected ConcurrentMap masterKeyCache; + + public TrinoVersionedLocalWrap() + { + } + + @Override + public void initialize(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, String kmsInstanceURL, String accessToken) + throws KeyAccessDeniedException + { + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + this.masterKeyCache = new ConcurrentHashMap(); + this.trinoParquetCryptoConfig = trinoParquetCryptoConfig; + this.kmsToken = accessToken; + this.initializeInternal(); + } + + @Override + public String wrapKey(byte[] keyBytes, String masterKeyIdentifier) + throws KeyAccessDeniedException + { + return null; + } + + @Override + public byte[] unwrapKey(String wrappedKey, String masterKeyIdentifier) + throws KeyAccessDeniedException + { + LocalKeyWrap keyWrap = LocalKeyWrap.parse(wrappedKey); + String masterKeyVersionedID = masterKeyIdentifier + ":" + keyWrap.getMasterKeyVersion(); + String encryptedEncodedKey = keyWrap.getEncryptedKey(); + byte[] masterKey = this.masterKeyCache.computeIfAbsent(masterKeyVersionedID, (k) -> 
this.getMasterKeyForVersion(masterKeyIdentifier, keyWrap.getMasterKeyVersion())); + return decryptKeyLocally(encryptedEncodedKey, masterKey, null); + } + + public static byte[] decryptKeyLocally(String encodedEncryptedKey, byte[] masterKeyBytes, byte[] aad) + { + byte[] encryptedKey = Base64.getDecoder().decode(encodedEncryptedKey); + AesGcmDecryptor keyDecryptor = (AesGcmDecryptor) ModuleCipherFactory.getDecryptor(AesMode.GCM, masterKeyBytes); + return keyDecryptor.decrypt(encryptedKey, 0, encryptedKey.length, aad); + } + + private byte[] getMasterKeyForVersion(String keyIdentifier, String keyVersion) + { + this.kmsToken = trinoParquetCryptoConfig.getEncryptionKeyAccessToken(); + byte[] key = this.getMasterKey(keyIdentifier, keyVersion); + this.checkMasterKeyLength(key.length, keyIdentifier, keyVersion); + return key; + } + + private void checkMasterKeyLength(int keyLength, String keyID, String keyVersion) + { + if (16 != keyLength && 24 != keyLength && 32 != keyLength) { + throw new ParquetCryptoRuntimeException("Wrong length: " + keyLength + " of master key: " + keyID + ", version: " + keyVersion); + } + } + + protected abstract MasterKeyWithVersion getMasterKey(String var1) + throws KeyAccessDeniedException; + + protected abstract byte[] getMasterKey(String var1, String var2) + throws KeyAccessDeniedException; + + protected abstract void initializeInternal() + throws KeyAccessDeniedException; + + static class LocalKeyWrap + { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapperProvider().get() + .enable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) + .enable(MapperFeature.ACCEPT_CASE_INSENSITIVE_ENUMS); + private final String encryptedEncodedKey; + private final String masterKeyVersion; + + private LocalKeyWrap(String masterKeyVersion, String encryptedEncodedKey) + { + this.masterKeyVersion = masterKeyVersion; + this.encryptedEncodedKey = encryptedEncodedKey; + } + + private static String createSerialized(String encryptedEncodedKey, String 
masterKeyVersion) + { + Map keyWrapMap = new HashMap(3); + keyWrapMap.put("localWrappingType", "LKW1"); + keyWrapMap.put("masterKeyVersion", masterKeyVersion); + keyWrapMap.put("encryptedKey", encryptedEncodedKey); + + try { + return OBJECT_MAPPER.writeValueAsString(keyWrapMap); + } + catch (IOException var4) { + throw new ParquetCryptoRuntimeException("Failed to serialize local key wrap map", var4); + } + } + + static LocalKeyWrap parse(String wrappedKey) + { + Map keyWrapMap; + try { + keyWrapMap = (Map) OBJECT_MAPPER.readValue(new StringReader(wrappedKey), new TypeReference>() { + }); + } + catch (IOException var5) { + throw new ParquetCryptoRuntimeException("Failed to parse local key wrap json " + wrappedKey, var5); + } + + String localWrappingType = (String) keyWrapMap.get("localWrappingType"); + String masterKeyVersion = (String) keyWrapMap.get("masterKeyVersion"); + if (null == localWrappingType) { + if (!"NO_VERSION".equals(masterKeyVersion)) { + throw new ParquetCryptoRuntimeException("No localWrappingType defined for key version: " + masterKeyVersion); + } + } + else if (!"LKW1".equals(localWrappingType)) { + throw new ParquetCryptoRuntimeException("Unsupported localWrappingType: " + localWrappingType); + } + + String encryptedEncodedKey = (String) keyWrapMap.get("encryptedKey"); + return new LocalKeyWrap(masterKeyVersion, encryptedEncodedKey); + } + + String getMasterKeyVersion() + { + return this.masterKeyVersion; + } + + private String getEncryptedKey() + { + return this.encryptedEncodedKey; + } + } + + public static class MasterKeyWithVersion + { + private final byte[] masterKey; + private final String masterKeyVersion; + + public MasterKeyWithVersion(byte[] masterKey, String masterKeyVersion) + { + this.masterKey = masterKey; + this.masterKeyVersion = masterKeyVersion; + } + + private byte[] getKey() + { + return this.masterKey; + } + + private String getVersion() + { + return this.masterKeyVersion; + } + } +} diff --git 
a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMaterial.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMaterial.java new file mode 100644 index 000000000000..4ef0d0f22217 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMaterial.java @@ -0,0 +1,242 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.IOException; +import java.io.StringReader; +import java.util.HashMap; +import java.util.Map; + +/** + * KeyMaterial class represents the "key material", keeping the information that allows readers to recover an encryption key (see + * description of the KeyMetadata class). The keytools package (PARQUET-1373) implements the "envelope encryption" pattern, in a + * "single wrapping" or "double wrapping" mode. In the single wrapping mode, the key material is generated by encrypting the + * "data encryption key" (DEK) by a "master key". In the double wrapping mode, the key material is generated by encrypting the DEK + * by a "key encryption key" (KEK), that in turn is encrypted by a "master key". + * + * Key material is kept in a flat json object, with the following fields: + * 1. "keyMaterialType" - a String, with the type of key material. 
In the current version, only one value is allowed - "PKMT1" (stands + * for "parquet key management tools, version 1"). For external key material storage, this field is written in both "key metadata" and + * "key material" jsons. For internal key material storage, this field is written only once in the common json. + * 2. "isFooterKey" - a boolean. If true, means that the material belongs to a file footer key, and keeps additional information (such as + * KMS instance ID and URL). If false, means that the material belongs to a column key. + * 3. "kmsInstanceID" - a String, with the KMS Instance ID. Written only in footer key material. + * 4. "kmsInstanceURL" - a String, with the KMS Instance URL. Written only in footer key material. + * 5. "masterKeyID" - a String, with the ID of the master key used to generate the material. + * 6. "wrappedDEK" - a String, with the wrapped DEK (base64 encoding). + * 7. "doubleWrapping" - a boolean. If true, means that the material was generated in double wrapping mode. + * If false - in single wrapping mode. + * 8. "keyEncryptionKeyID" - a String, with the ID of the KEK used to generate the material. Written only in double wrapping mode. + * 9. "wrappedKEK" - a String, with the wrapped KEK (base64 encoding). Written only in double wrapping mode. 
+ */ +public class KeyMaterial +{ + static final String KEY_MATERIAL_TYPE_FIELD = "keyMaterialType"; + static final String KEY_MATERIAL_TYPE1 = "PKMT1"; + + static final String FOOTER_KEY_ID_IN_FILE = "footerKey"; + static final String COLUMN_KEY_ID_IN_FILE_PREFIX = "columnKey"; + + private static final String IS_FOOTER_KEY_FIELD = "isFooterKey"; + private static final String DOUBLE_WRAPPING_FIELD = "doubleWrapping"; + private static final String KMS_INSTANCE_ID_FIELD = "kmsInstanceID"; + private static final String KMS_INSTANCE_URL_FIELD = "kmsInstanceURL"; + private static final String MASTER_KEY_ID_FIELD = "masterKeyID"; + private static final String WRAPPED_DEK_FIELD = "wrappedDEK"; + private static final String KEK_ID_FIELD = "keyEncryptionKeyID"; + private static final String WRAPPED_KEK_FIELD = "wrappedKEK"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private final boolean isFooterKey; + private final String kmsInstanceID; + private final String kmsInstanceURL; + private final String masterKeyID; + private final boolean isDoubleWrapped; + private final String kekID; + private final String encodedWrappedKEK; + private final String encodedWrappedDEK; + + private KeyMaterial( + boolean isFooterKey, + String kmsInstanceID, + String kmsInstanceURL, + String masterKeyID, + boolean isDoubleWrapped, + String kekID, + String encodedWrappedKEK, + String encodedWrappedDEK) + { + this.isFooterKey = isFooterKey; + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + this.masterKeyID = masterKeyID; + this.isDoubleWrapped = isDoubleWrapped; + this.kekID = kekID; + this.encodedWrappedKEK = encodedWrappedKEK; + this.encodedWrappedDEK = encodedWrappedDEK; + } + + // parses external key material + static KeyMaterial parse(String keyMaterialString) + { + Map keyMaterialJson = null; + try { + keyMaterialJson = OBJECT_MAPPER.readValue( + new StringReader(keyMaterialString), new TypeReference>() {}); + } + catch 
(IOException e) { + throw new ParquetCryptoRuntimeException("Failed to parse key metadata " + keyMaterialString, e); + } + // 1. External key material - extract "key material type", and make sure it is supported + String keyMaterialType = (String) keyMaterialJson.get(KEY_MATERIAL_TYPE_FIELD); + if (!KEY_MATERIAL_TYPE1.equals(keyMaterialType)) { + throw new ParquetCryptoRuntimeException( + "Wrong key material type: " + keyMaterialType + " vs " + KEY_MATERIAL_TYPE1); + } + // Parse other fields (common to internal and external key material) + return parse(keyMaterialJson); + } + + // parses fields common to internal and external key material + static KeyMaterial parse(Map keyMaterialJson) + { + // 2. Check if "key material" belongs to file footer key + Boolean isFooterKey = (Boolean) keyMaterialJson.get(IS_FOOTER_KEY_FIELD); + String kmsInstanceID = null; + String kmsInstanceURL = null; + if (isFooterKey) { + // 3. For footer key, extract KMS Instance ID + kmsInstanceID = (String) keyMaterialJson.get(KMS_INSTANCE_ID_FIELD); + // 4. For footer key, extract KMS Instance URL + kmsInstanceURL = (String) keyMaterialJson.get(KMS_INSTANCE_URL_FIELD); + } + // 5. Extract master key ID + String masterKeyID = (String) keyMaterialJson.get(MASTER_KEY_ID_FIELD); + // 6. Extract wrapped DEK + String encodedWrappedDEK = (String) keyMaterialJson.get(WRAPPED_DEK_FIELD); + String kekID = null; + String encodedWrappedKEK = null; + // 7. Check if "key material" was generated in double wrapping mode + Boolean isDoubleWrapped = (Boolean) keyMaterialJson.get(DOUBLE_WRAPPING_FIELD); + if (isDoubleWrapped) { + // 8. In double wrapping mode, extract KEK ID + kekID = (String) keyMaterialJson.get(KEK_ID_FIELD); + // 9. 
In double wrapping mode, extract wrapped KEK + encodedWrappedKEK = (String) keyMaterialJson.get(WRAPPED_KEK_FIELD); + } + + return new KeyMaterial( + isFooterKey, + kmsInstanceID, + kmsInstanceURL, + masterKeyID, + isDoubleWrapped, + kekID, + encodedWrappedKEK, + encodedWrappedDEK); + } + + static String createSerialized( + boolean isFooterKey, + String kmsInstanceID, + String kmsInstanceURL, + String masterKeyID, + boolean isDoubleWrapped, + String kekID, + String encodedWrappedKEK, + String encodedWrappedDEK, + boolean isInternalStorage) + { + Map keyMaterialMap = new HashMap(10); + // 1. Write "key material type" + keyMaterialMap.put(KEY_MATERIAL_TYPE_FIELD, KEY_MATERIAL_TYPE1); + if (isInternalStorage) { + // for internal storage, key material and key metadata are the same. + // adding the "internalStorage" field that belongs to KeyMetadata. + keyMaterialMap.put(KeyMetadata.KEY_MATERIAL_INTERNAL_STORAGE_FIELD, Boolean.TRUE); + } + // 2. Write isFooterKey + keyMaterialMap.put(IS_FOOTER_KEY_FIELD, isFooterKey); + if (isFooterKey) { + // 3. For footer key, write KMS Instance ID + keyMaterialMap.put(KMS_INSTANCE_ID_FIELD, kmsInstanceID); + // 4. For footer key, write KMS Instance URL + keyMaterialMap.put(KMS_INSTANCE_URL_FIELD, kmsInstanceURL); + } + // 5. Write master key ID + keyMaterialMap.put(MASTER_KEY_ID_FIELD, masterKeyID); + // 6. Write wrapped DEK + keyMaterialMap.put(WRAPPED_DEK_FIELD, encodedWrappedDEK); + // 7. Write isDoubleWrapped + keyMaterialMap.put(DOUBLE_WRAPPING_FIELD, isDoubleWrapped); + if (isDoubleWrapped) { + // 8. In double wrapping mode, write KEK ID + keyMaterialMap.put(KEK_ID_FIELD, kekID); + // 9. 
In double wrapping mode, write wrapped KEK + keyMaterialMap.put(WRAPPED_KEK_FIELD, encodedWrappedKEK); + } + + try { + return OBJECT_MAPPER.writeValueAsString(keyMaterialMap); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to serialize key material", e); + } + } + + boolean isFooterKey() + { + return isFooterKey; + } + + boolean isDoubleWrapped() + { + return isDoubleWrapped; + } + + String getMasterKeyID() + { + return masterKeyID; + } + + String getWrappedDEK() + { + return encodedWrappedDEK; + } + + String getKekID() + { + return kekID; + } + + String getWrappedKEK() + { + return encodedWrappedKEK; + } + + String getKmsInstanceID() + { + return kmsInstanceID; + } + + String getKmsInstanceURL() + { + return kmsInstanceURL; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMetadata.java new file mode 100644 index 000000000000..54d4d7227dc9 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMetadata.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.IOException; +import java.io.StringReader; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +/** + * Parquet encryption specification defines "key metadata" as an arbitrary byte array, generated by file writers for each encryption key, + * and passed to the low level API for storage in the file footer . The "key metadata" field is made available to file readers to enable + * recovery of the key. This simple interface can be utilized for implementation of any key management scheme. + * + * The keytools package (PARQUET-1373) implements one approach, of many possible, to key management and to generation of the "key metadata" + * fields. This approach, based on the "envelope encryption" pattern, allows to work with KMS servers. It keeps the actual material, + * required to recover a key, in a "key material" object (see the KeyMaterial class for details). + * + * KeyMetadata class writes (and reads) the "key metadata" field as a flat json object, with the following fields: + * 1. "keyMaterialType" - a String, with the type of key material. In the current version, only one value is allowed - "PKMT1" (stands + * for "parquet key management tools, version 1") + * 2. "internalStorage" - a boolean. If true, means that "key material" is kept inside the "key metadata" field. If false, "key material" + * is kept externally (outside Parquet files) - in this case, "key metadata" keeps a reference to the external "key material". + * 3. "keyReference" - a String, with the reference to the external "key material". Written only if internalStorage is false. 
+ * + * If internalStorage is true, "key material" is a part of "key metadata", and the json keeps additional fields, described in the + * KeyMaterial class. + */ +public class KeyMetadata +{ + static final String KEY_MATERIAL_INTERNAL_STORAGE_FIELD = "internalStorage"; + private static final String KEY_REFERENCE_FIELD = "keyReference"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private final boolean isInternalStorage; + private final String keyReference; + private final KeyMaterial keyMaterial; + + private KeyMetadata(boolean isInternalStorage, String keyReference, KeyMaterial keyMaterial) + { + this.isInternalStorage = isInternalStorage; + this.keyReference = keyReference; + this.keyMaterial = keyMaterial; + } + + static KeyMetadata parse(byte[] keyMetadataBytes) + { + String keyMetaDataString = new String(keyMetadataBytes, StandardCharsets.UTF_8); + Map keyMetadataJson = null; + try { + keyMetadataJson = OBJECT_MAPPER.readValue( + new StringReader(keyMetaDataString), new TypeReference>() {}); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to parse key metadata " + keyMetaDataString, e); + } + + // 1. Extract "key material type", and make sure it is supported + String keyMaterialType = (String) keyMetadataJson.get(KeyMaterial.KEY_MATERIAL_TYPE_FIELD); + if (!KeyMaterial.KEY_MATERIAL_TYPE1.equals(keyMaterialType)) { + throw new ParquetCryptoRuntimeException( + "Wrong key material type: " + keyMaterialType + " vs " + KeyMaterial.KEY_MATERIAL_TYPE1); + } + + // 2. 
Check if "key material" is stored internally in Parquet file key metadata, or is stored externally + Boolean isInternalStorage = (Boolean) keyMetadataJson.get(KEY_MATERIAL_INTERNAL_STORAGE_FIELD); + String keyReference; + KeyMaterial keyMaterial; + + if (isInternalStorage) { + // 3.1 "key material" is stored internally, inside "key metadata" - parse it + keyMaterial = KeyMaterial.parse(keyMetadataJson); + keyReference = null; + } + else { + // 3.2 "key material" is stored externally. "key metadata" keeps a reference to it + keyReference = (String) keyMetadataJson.get(KEY_REFERENCE_FIELD); + keyMaterial = null; + } + + return new KeyMetadata(isInternalStorage, keyReference, keyMaterial); + } + + // For external material only. For internal material, create serialized KeyMaterial directly + static String createSerializedForExternalMaterial(String keyReference) + { + Map keyMetadataMap = new HashMap(3); + // 1. Write "key material type" + keyMetadataMap.put(KeyMaterial.KEY_MATERIAL_TYPE_FIELD, KeyMaterial.KEY_MATERIAL_TYPE1); + // 2. Write internal storage as false + keyMetadataMap.put(KEY_MATERIAL_INTERNAL_STORAGE_FIELD, Boolean.FALSE); + // 3. 
For externally stored "key material", "key metadata" keeps only a reference to it + keyMetadataMap.put(KEY_REFERENCE_FIELD, keyReference); + + try { + return OBJECT_MAPPER.writeValueAsString(keyMetadataMap); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to serialize key metadata", e); + } + } + + boolean keyMaterialStoredInternally() + { + return isInternalStorage; + } + + KeyMaterial getKeyMaterial() + { + return keyMaterial; + } + + String getKeyReference() + { + return keyReference; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java new file mode 100644 index 000000000000..0c5bee3b6da8 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java @@ -0,0 +1,164 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.parquet.crypto.keytools; + +import com.google.common.base.Strings; +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KeyWithMasterID; + +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +import static io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KEK_READ_CACHE_PER_TOKEN; +import static io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KMS_CLIENT_CACHE_PER_TOKEN; + +public class TrinoFileKeyUnwrapper + implements DecryptionKeyRetriever +{ + private static final Logger LOG = Logger.get(TrinoFileKeyUnwrapper.class); + + //A map of KEK_ID -> KEK bytes, for the current token + private final ConcurrentMap kekPerKekID; + private final Location parquetFilePath; + // TODO(wyu): shall we get it from Location or File + private final TrinoFileSystem trinoFileSystem; + private final String accessToken; + private final long cacheEntryLifetime; + private final ParquetReaderOptions parquetReaderOptions; + private TrinoKeyToolkit.TrinoKmsClientAndDetails kmsClientAndDetails; + private TrinoHadoopFSKeyMaterialStore keyMaterialStore; + private boolean checkedKeyMaterialInternalStorage; + + TrinoFileKeyUnwrapper(ParquetReaderOptions conf, Location filePath, TrinoFileSystem trinoFileSystem) + { + this.trinoFileSystem = trinoFileSystem; + this.parquetReaderOptions = conf; + this.parquetFilePath = filePath; + this.cacheEntryLifetime = 1000L * conf.getEncryptionCacheLifetimeSeconds(); + this.accessToken = conf.getEncryptionKeyAccessToken(); + this.kmsClientAndDetails = null; + this.keyMaterialStore = null; + this.checkedKeyMaterialInternalStorage = false; + + // Check cache upon each file reading (clean once in cacheEntryLifetime) + 
KMS_CLIENT_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + KEK_READ_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + kekPerKekID = KEK_READ_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + if (LOG.isDebugEnabled()) { + LOG.debug("Creating file key unwrapper. KeyMaterialStore: %s; token snippet: %s", + keyMaterialStore, TrinoKeyToolkit.formatTokenForLog(accessToken)); + } + } + + @Override + public byte[] getKey(byte[] keyMetadataBytes) + { + KeyMetadata keyMetadata = KeyMetadata.parse(keyMetadataBytes); + + if (!checkedKeyMaterialInternalStorage) { + if (!keyMetadata.keyMaterialStoredInternally()) { + keyMaterialStore = new TrinoHadoopFSKeyMaterialStore(trinoFileSystem, parquetFilePath, false); + } + checkedKeyMaterialInternalStorage = true; + } + + KeyMaterial keyMaterial; + if (keyMetadata.keyMaterialStoredInternally()) { + // Internal key material storage: key material is inside key metadata + keyMaterial = keyMetadata.getKeyMaterial(); + } + else { + // External key material storage: key metadata contains a reference to a key in the material store + String keyIDinFile = keyMetadata.getKeyReference(); + String keyMaterialString = keyMaterialStore.getKeyMaterial(keyIDinFile); + if (null == keyMaterialString) { + throw new ParquetCryptoRuntimeException("Null key material for keyIDinFile: " + keyIDinFile); + } + keyMaterial = KeyMaterial.parse(keyMaterialString); + } + + return getDEKandMasterID(keyMaterial).getDataKey(); + } + + KeyWithMasterID getDEKandMasterID(KeyMaterial keyMaterial) + { + if (null == kmsClientAndDetails) { + kmsClientAndDetails = getKmsClientFromConfigOrKeyMaterial(keyMaterial); + } + + boolean doubleWrapping = keyMaterial.isDoubleWrapped(); + String masterKeyID = keyMaterial.getMasterKeyID(); + String encodedWrappedDEK = keyMaterial.getWrappedDEK(); + + byte[] dataKey; + TrinoKmsClient kmsClient = kmsClientAndDetails.getKmsClient(); + if (!doubleWrapping) { + dataKey = 
kmsClient.unwrapKey(encodedWrappedDEK, masterKeyID); + } + else { + // Get KEK + String encodedKekID = keyMaterial.getKekID(); + String encodedWrappedKEK = keyMaterial.getWrappedKEK(); + + byte[] kekBytes = kekPerKekID.computeIfAbsent(encodedKekID, + (k) -> kmsClient.unwrapKey(encodedWrappedKEK, masterKeyID)); + + if (null == kekBytes) { + throw new ParquetCryptoRuntimeException("Null KEK, after unwrapping in KMS with master key " + masterKeyID); + } + + // Decrypt the data key + byte[] aad = Base64.getDecoder().decode(encodedKekID); + dataKey = TrinoKeyToolkit.decryptKeyLocally(encodedWrappedDEK, kekBytes, aad); + } + + return new KeyWithMasterID(dataKey, masterKeyID); + } + + TrinoKeyToolkit.TrinoKmsClientAndDetails getKmsClientFromConfigOrKeyMaterial(KeyMaterial keyMaterial) + { + String kmsInstanceID = this.parquetReaderOptions.getEncryptionKmsInstanceId(); + if (Strings.isNullOrEmpty(kmsInstanceID)) { + kmsInstanceID = keyMaterial.getKmsInstanceID(); + if (null == kmsInstanceID) { + throw new ParquetCryptoRuntimeException("KMS instance ID is missing both in properties and file key material"); + } + } + + String kmsInstanceURL = this.parquetReaderOptions.getEncryptionKmsInstanceUrl(); + if (Strings.isNullOrEmpty(kmsInstanceURL)) { + kmsInstanceURL = keyMaterial.getKmsInstanceURL(); + if (null == kmsInstanceURL) { + throw new ParquetCryptoRuntimeException("KMS instance URL is missing both in properties and file key material"); + } + } + + TrinoKmsClient kmsClient = TrinoKeyToolkit.getKmsClient(kmsInstanceID, kmsInstanceURL, this.parquetReaderOptions, accessToken, cacheEntryLifetime); + if (null == kmsClient) { + throw new ParquetCryptoRuntimeException("KMSClient was not successfully created for reading encrypted data."); + } + + if (LOG.isDebugEnabled()) { + LOG.debug("File unwrapper - KmsClient: %s; InstanceId: %s; InstanceURL: %s", kmsClient, kmsInstanceID, kmsInstanceURL); + } + return new TrinoKeyToolkit.TrinoKmsClientAndDetails(kmsClient, kmsInstanceID, 
kmsInstanceURL); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java new file mode 100644 index 000000000000..4c178c0bd8fe --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoInputStream; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Map; + +public class TrinoHadoopFSKeyMaterialStore +{ + public static final String KEY_MATERIAL_FILE_PREFIX = "_KEY_MATERIAL_FOR_"; + public static final String TEMP_FILE_PREFIX = "_TMP"; + public static final String KEY_MATERIAL_FILE_SUFFFIX = ".json"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private TrinoFileSystem trinoFileSystem; + private Map keyMaterialMap; + private Location keyMaterialFile; + + 
TrinoHadoopFSKeyMaterialStore(TrinoFileSystem trinoFileSystem, Location parquetFilePath, boolean tempStore) + { + this.trinoFileSystem = trinoFileSystem; + String fullPrefix = (tempStore ? TEMP_FILE_PREFIX : ""); + fullPrefix += KEY_MATERIAL_FILE_PREFIX; + keyMaterialFile = parquetFilePath.parentDirectory().appendPath( + fullPrefix + parquetFilePath.fileName() + KEY_MATERIAL_FILE_SUFFFIX); + } + + public String getKeyMaterial(String keyIDInFile) + throws ParquetCryptoRuntimeException + { + if (null == keyMaterialMap) { + loadKeyMaterialMap(); + } + return keyMaterialMap.get(keyIDInFile); + } + + private void loadKeyMaterialMap() + { + TrinoInputFile inputfile = trinoFileSystem.newInputFile(keyMaterialFile); + try (TrinoInputStream keyMaterialStream = inputfile.newStream()) { + JsonNode keyMaterialJson = objectMapper.readTree(keyMaterialStream); + keyMaterialMap = objectMapper.readValue(keyMaterialJson.traverse(), + new TypeReference>() {}); + } + catch (FileNotFoundException e) { + throw new ParquetCryptoRuntimeException("External key material not found at " + keyMaterialFile, e); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to get key material from " + keyMaterialFile, e); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java new file mode 100644 index 000000000000..eb05702732ba --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.AesGcmDecryptor; +import io.trino.parquet.crypto.AesMode; +import io.trino.parquet.crypto.ModuleCipherFactory; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TrinoCryptoConfigurationUtil; + +import java.lang.reflect.InvocationTargetException; +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +public class TrinoKeyToolkit +{ + public static final long CACHE_LIFETIME_DEFAULT_SECONDS = 10 * 60; // 10 minutes + + // KMS client two level cache: token -> KMSInstanceId -> KmsClient + static final TwoLevelCacheWithExpiration KMS_CLIENT_CACHE_PER_TOKEN = + KmsClientCache.INSTANCE.getCache(); + + // KEK two level cache for unwrapping: token -> KEK_ID -> KEK bytes + static final TwoLevelCacheWithExpiration KEK_READ_CACHE_PER_TOKEN = + KEKReadCache.INSTANCE.getCache(); + + private TrinoKeyToolkit() + { + } + + private enum KmsClientCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + private enum KEKReadCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + static String formatTokenForLog(String accessToken) + { + int maxTokenDisplayLength = 5; + if (accessToken.length() <= 
maxTokenDisplayLength) { + return accessToken; + } + return accessToken.substring(accessToken.length() - maxTokenDisplayLength); + } + + static class KeyWithMasterID + { + private final byte[] keyBytes; + private final String masterID; + + KeyWithMasterID(byte[] keyBytes, String masterID) + { + this.keyBytes = keyBytes; + this.masterID = masterID; + } + + byte[] getDataKey() + { + return keyBytes; + } + + String getMasterID() + { + return masterID; + } + } + + static class KeyEncryptionKey + { + private final byte[] kekBytes; + private final byte[] kekID; + private String encodedKekID; + private final String encodedWrappedKEK; + + KeyEncryptionKey(byte[] kekBytes, byte[] kekID, String encodedWrappedKEK) + { + this.kekBytes = kekBytes; + this.kekID = kekID; + this.encodedWrappedKEK = encodedWrappedKEK; + } + + byte[] getBytes() + { + return kekBytes; + } + + byte[] getID() + { + return kekID; + } + + String getEncodedID() + { + if (null == encodedKekID) { + encodedKekID = Base64.getEncoder().encodeToString(kekID); + } + return encodedKekID; + } + + String getEncodedWrappedKEK() + { + return encodedWrappedKEK; + } + } + + /** + * Decrypts encrypted key with "masterKey", using AES-GCM and the "aad" + * + * @param encodedEncryptedKey base64 encoded encrypted key + * @param masterKeyBytes encryption key + * @param aad additional authenticated data + * @return decrypted key + */ + public static byte[] decryptKeyLocally(String encodedEncryptedKey, byte[] masterKeyBytes, byte[] aad) + { + byte[] encryptedKey = Base64.getDecoder().decode(encodedEncryptedKey); + + AesGcmDecryptor keyDecryptor; + + keyDecryptor = (AesGcmDecryptor) ModuleCipherFactory.getDecryptor(AesMode.GCM, masterKeyBytes); + + return keyDecryptor.decrypt(encryptedKey, 0, encryptedKey.length, aad); + } + + static TrinoKmsClient getKmsClient(String kmsInstanceID, String kmsInstanceURL, ParquetReaderOptions trinoParquetCryptoConfig, + String accessToken, long cacheEntryLifetime) + { + ConcurrentMap 
kmsClientPerKmsInstanceCache = + KMS_CLIENT_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + TrinoKmsClient kmsClient = + kmsClientPerKmsInstanceCache.computeIfAbsent(kmsInstanceID, + (k) -> createAndInitKmsClient(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken)); + + return kmsClient; + } + + private static TrinoKmsClient createAndInitKmsClient(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, + String kmsInstanceURL, String accessToken) + { + Class> kmsClientClass = null; + TrinoKmsClient kmsClient; + + try { + kmsClientClass = TrinoCryptoConfigurationUtil.getClassFromConfig(trinoParquetCryptoConfig.getEncryptionKmsClientClass(), + TrinoKmsClient.class); + + if (null == kmsClientClass) { + throw new ParquetCryptoRuntimeException("Could not find class " + trinoParquetCryptoConfig.getEncryptionKmsClientClass()); + } + kmsClient = (TrinoKmsClient) kmsClientClass.getConstructor().newInstance(); + } + catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + throw new ParquetCryptoRuntimeException("Could not instantiate KmsClient class: " + + kmsClientClass, e); + } + + kmsClient.initialize(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken); + + return kmsClient; + } + + static class TrinoKmsClientAndDetails + { + public TrinoKmsClient getKmsClient() + { + return kmsClient; + } + + private TrinoKmsClient kmsClient; + private String kmsInstanceID; + private String kmsInstanceURL; + + public TrinoKmsClientAndDetails(TrinoKmsClient kmsClient, String kmsInstanceID, String kmsInstanceURL) + { + this.kmsClient = kmsClient; + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + } + + public String getKmsInstanceID() + { + return kmsInstanceID; + } + + public String getKmsInstanceURL() + { + return kmsInstanceURL; + } + } +} diff --git 
a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java new file mode 100644 index 000000000000..6ca6cb0cb53e --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.KeyAccessDeniedException; + +public interface TrinoKmsClient +{ + String KEY_ACCESS_TOKEN_DEFAULT = "DEFAULT"; + + void initialize(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, String kmsInstanceURL, String accessToken) + throws KeyAccessDeniedException; + + String wrapKey(byte[] keyBytes, String masterKeyIdentifier) + throws KeyAccessDeniedException; + + byte[] unwrapKey(String wrappedKey, String masterKeyIdentifier) + throws KeyAccessDeniedException; +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java new file mode 100644 index 000000000000..8eb61c18c0e8 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.FileDecryptionProperties; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TrinoDecryptionPropertiesFactory; + +public class TrinoPropertiesDrivenCryptoFactory + implements TrinoDecryptionPropertiesFactory +{ + private static final Logger LOG = Logger.get(TrinoPropertiesDrivenCryptoFactory.class); + + @Override + public FileDecryptionProperties getFileDecryptionProperties(ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + throws ParquetCryptoRuntimeException + { + DecryptionKeyRetriever keyRetriever = new TrinoFileKeyUnwrapper(parquetReaderOptions, filePath, trinoFileSystem); + + if (LOG.isDebugEnabled()) { + LOG.debug("File decryption properties for %s", filePath); + } + + return FileDecryptionProperties.builder() + .withKeyRetriever(keyRetriever) + .withPlaintextFilesAllowed() + .build(); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java new file mode 100644 index 000000000000..ca2e7d2d356d --- /dev/null +++ 
b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * Concurrent two-level cache with expiration of internal caches according to token lifetime. + * External cache is per token, internal is per String key. + * + * @param Value + */ +class TwoLevelCacheWithExpiration +{ + private final ConcurrentMap>> cache; + private volatile long lastCacheCleanupTimestamp; + + TwoLevelCacheWithExpiration() + { + this.cache = new ConcurrentHashMap<>(); + this.lastCacheCleanupTimestamp = System.currentTimeMillis(); + } + + ConcurrentMap getOrCreateInternalCache(String accessToken, long cacheEntryLifetime) + { + ExpiringCacheEntry> externalCacheEntry = + cache.compute(accessToken, (token, cacheEntry) -> { + if ((null == cacheEntry) || cacheEntry.isExpired()) { + return new ExpiringCacheEntry<>(new ConcurrentHashMap(), cacheEntryLifetime); + } + else { + return cacheEntry; + } + }); + return externalCacheEntry.getCachedItem(); + } + + void removeCacheEntriesForToken(String accessToken) + { + cache.remove(accessToken); + } + + void removeCacheEntriesForAllTokens() + { + cache.clear(); + } + + public void checkCacheForExpiredTokens(long cacheCleanupPeriod) + { + long now = System.currentTimeMillis(); + + if (now > 
(lastCacheCleanupTimestamp + cacheCleanupPeriod)) { + synchronized (cache) { + if (now > (lastCacheCleanupTimestamp + cacheCleanupPeriod)) { + removeExpiredEntriesFromCache(); + lastCacheCleanupTimestamp = now; // the guard above adds cacheCleanupPeriod, so store the actual cleanup time + } + } + } + } + + public void removeExpiredEntriesFromCache() + { + cache.values().removeIf(cacheEntry -> cacheEntry.isExpired()); + } + + public void remove(String accessToken) + { + cache.remove(accessToken); + } + + public void clear() + { + cache.clear(); + } + + static class ExpiringCacheEntry + { + private final long expirationTimestamp; + private final E cachedItem; + + private ExpiringCacheEntry(E cachedItem, long expirationIntervalMillis) + { + this.expirationTimestamp = System.currentTimeMillis() + expirationIntervalMillis; + this.cachedItem = cachedItem; + } + + private boolean isExpired() + { + final long now = System.currentTimeMillis(); + return (now > expirationTimestamp); + } + + private E getCachedItem() + { + return cachedItem; + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java index 43defc21b834..1a955515fe50 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java @@ -15,7 +15,7 @@ import java.util.List; -public record BlockMetadata(long rowCount, List columns) +public record BlockMetadata(long rowCount, long totalByteSize, short ordinal, List columns) { public long getStartingPos() { diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java index 381260829869..0c9c85c95aee 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java 
@@ -23,6 +23,9 @@ import java.util.Set; +import static io.trino.parquet.ParquetEncoding.PLAIN_DICTIONARY; +import static io.trino.parquet.ParquetEncoding.RLE_DICTIONARY; + public abstract class ColumnChunkMetadata { protected int rowGroupOrdinal = -1; @@ -200,4 +203,16 @@ public String toString() decryptIfNeeded(); return "ColumnMetaData{" + properties.toString() + ", " + getFirstDataPageOffset() + "}"; } + + public boolean hasDictionaryPage() + { + EncodingStats stats = getEncodingStats(); + if (stats != null) { + // ensure there is a dictionary page and that it is used to encode data pages + return stats.hasDictionaryPages() && stats.hasDictionaryEncodedPages(); + } + + Set encodings = getEncodings(); + return (encodings.contains(PLAIN_DICTIONARY) || encodings.contains(RLE_DICTIONARY)); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java index 6901bb23a4e6..3293e980e719 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java @@ -25,6 +25,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; @@ -230,9 +231,11 @@ private static Map> getStatistics(PrunedBlockMet ImmutableMap.Builder> statistics = ImmutableMap.builderWithExpectedSize(descriptorsByPath.size()); for (ColumnDescriptor descriptor : descriptorsByPath.values()) { ColumnChunkMetadata columnMetaData = columnsMetadata.getColumnChunkMetaData(descriptor); - Statistics> columnStatistics = columnMetaData.getStatistics(); - if (columnStatistics != null) { - 
statistics.put(descriptor, columnStatistics); + if (!HiddenColumnChunkMetaData.isHiddenColumn(columnMetaData)) { + Statistics> columnStatistics = columnMetaData.getStatistics(); + if (columnStatistics != null) { + statistics.put(descriptor, columnStatistics); + } } } return statistics.buildOrThrow(); @@ -260,18 +263,20 @@ private static boolean dictionaryPredicatesMatch( { for (ColumnDescriptor descriptor : descriptorsByPath.values()) { ColumnChunkMetadata columnMetaData = columnsMetadata.getColumnChunkMetaData(descriptor); - if (!candidateColumns.contains(descriptor)) { - continue; - } - if (isOnlyDictionaryEncodingPages(columnMetaData)) { - Statistics> columnStatistics = columnMetaData.getStatistics(); - boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0; - // Early abort, predicate already filters block so no more dictionaries need be read - if (!parquetPredicate.matches(new DictionaryDescriptor( - descriptor, - nullAllowed, - readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { - return false; + if (!HiddenColumnChunkMetaData.isHiddenColumn(columnMetaData)) { + if (!candidateColumns.contains(descriptor)) { + continue; + } + if (isOnlyDictionaryEncodingPages(columnMetaData)) { + Statistics> columnStatistics = columnMetaData.getStatistics(); + boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0; + // Early abort, predicate already filters block so no more dictionaries need be read + if (!parquetPredicate.matches(new DictionaryDescriptor( + descriptor, + nullAllowed, + readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { + return false; + } } } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index fe0635646f98..294cfe0604b2 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ 
b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -15,26 +15,41 @@ import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; +import io.airlift.slice.BasicSliceInput; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.AesGcmEncryptor; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.KeyAccessDeniedException; +import io.trino.parquet.crypto.ModuleCipherFactory.ModuleType; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TagVerificationException; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.FileMetadata; import io.trino.parquet.metadata.ParquetMetadata; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; +import org.apache.parquet.format.BlockCipher.Decryptor; import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnCryptoMetaData; import org.apache.parquet.format.ColumnMetaData; import org.apache.parquet.format.Encoding; +import org.apache.parquet.format.EncryptionWithColumnKey; +import org.apache.parquet.format.FileCryptoMetaData; import org.apache.parquet.format.FileMetaData; import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.SchemaElement; import org.apache.parquet.format.Statistics; +import org.apache.parquet.format.Util; import org.apache.parquet.hadoop.metadata.ColumnPath; import 
org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -43,6 +58,7 @@ import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; @@ -56,7 +72,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.parquet.ParquetMetadataConverter.convertEncodingStats; import static io.trino.parquet.ParquetMetadataConverter.fromParquetStatistics; import static io.trino.parquet.ParquetMetadataConverter.getEncoding; @@ -69,6 +87,7 @@ import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static org.apache.parquet.format.Util.readFileCryptoMetaData; import static org.apache.parquet.format.Util.readFileMetaData; public final class MetadataReader @@ -76,13 +95,14 @@ public final class MetadataReader private static final Logger log = Logger.get(MetadataReader.class); private static final Slice MAGIC = Slices.utf8Slice("PAR1"); + private static final Slice EMAGIC = Slices.utf8Slice("PARE"); private static final int POST_SCRIPT_SIZE = Integer.BYTES + MAGIC.length(); // Typical 1GB files produced by Trino were found to have footer size between 30-40KB private static final int EXPECTED_FOOTER_SIZE = 48 * 1024; private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation, Optional fileDecryptor) throws IOException { // Parquet File Layout: @@ -93,7 +113,9 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< // 4 bytes: MetadataLength // MAGIC - 
validateParquet(dataSource.getEstimatedSize() >= MAGIC.length() + POST_SCRIPT_SIZE, dataSource.getId(), "%s is not a valid Parquet File", dataSource.getId()); + validateParquet((dataSource.getEstimatedSize() >= MAGIC.length() + POST_SCRIPT_SIZE) || + (dataSource.getEstimatedSize() >= EMAGIC.length() + POST_SCRIPT_SIZE), dataSource.getId(), + "%s is not a valid Parquet File", dataSource.getId()); // Read the tail of the file long estimatedFileSize = dataSource.getEstimatedSize(); @@ -101,8 +123,10 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< Slice buffer = dataSource.readTail(toIntExact(expectedReadSize)); Slice magic = buffer.slice(buffer.length() - MAGIC.length(), MAGIC.length()); - validateParquet(MAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s got: %s", MAGIC.toStringUtf8(), magic.toStringUtf8()); + validateParquet(MAGIC.equals(magic) || EMAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s or %s got: %s", MAGIC.toStringUtf8(), EMAGIC.toStringUtf8(), magic.toStringUtf8()); + boolean encryptedFooterMode = EMAGIC.equals(magic); + checkArgument(!encryptedFooterMode || !(fileDecryptor.isEmpty() || fileDecryptor.get().getDecryptionProperties() == null), "fileDecryptionProperties cannot be null when encryptedFooterMode is true"); int metadataLength = buffer.getInt(buffer.length() - POST_SCRIPT_SIZE); long metadataIndex = estimatedFileSize - POST_SCRIPT_SIZE - metadataLength; validateParquet( @@ -118,13 +142,44 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< } InputStream metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); - FileMetaData fileMetaData = readFileMetaData(metadataStream); - ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, dataSource.getId()); + Decryptor footerDecryptor = null; + byte[] aad = null; + + if (encryptedFooterMode) { + FileCryptoMetaData fileCryptoMetaData = 
readFileCryptoMetaData(metadataStream); + fileDecryptor.get().setFileCryptoMetaData(fileCryptoMetaData.getEncryption_algorithm(), true, fileCryptoMetaData.getKey_metadata()); + footerDecryptor = fileDecryptor.get().fetchFooterDecryptor(); + aad = AesCipher.createFooterAAD(fileDecryptor.get().getFileAAD()); + } + FileMetaData fileMetaData = readFileMetaData(metadataStream, footerDecryptor, aad); + if (!encryptedFooterMode && fileDecryptor.isPresent()) { + if (!fileMetaData.isSetEncryption_algorithm()) { // Plaintext file + fileDecryptor.get().setPlaintextFile(); + // Done to detect files that were not encrypted by mistake + if (!fileDecryptor.get().plaintextFilesAllowed()) { + throw new ParquetCryptoRuntimeException("Applying decryptor on plaintext file"); + } + } + else { // Encrypted file with plaintext footer + // if no fileDecryptor, can still read plaintext columns + fileDecryptor.get().setFileCryptoMetaData(fileMetaData.getEncryption_algorithm(), false, + fileMetaData.getFooter_signing_key_metadata()); + if (fileDecryptor.get().checkFooterIntegrity()) { + verifyFooterIntegrity(metadataStream, fileDecryptor.get(), metadataLength); + } + } + } + ParquetDataSourceId id = dataSource.getId(); + ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, id, fileDecryptor, encryptedFooterMode); + validateFileMetadata(id, parquetMetadata.getFileMetaData(), parquetWriteValidation); validateFileMetadata(dataSource.getId(), parquetMetadata.getFileMetaData(), parquetWriteValidation); return parquetMetadata; } - public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, ParquetDataSourceId dataSourceId) + public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, + ParquetDataSourceId dataSourceId, + Optional fileDecryptor, + boolean encryptedFooterMode) throws ParquetCorruptionException { List schema = fileMetaData.getSchema(); @@ -138,37 +193,79 @@ public static ParquetMetadata createParquetMetadata(FileMetaData 
fileMetaData, P List columns = rowGroup.getColumns(); validateParquet(!columns.isEmpty(), dataSourceId, "No columns in row group: %s", rowGroup); String filePath = columns.get(0).getFile_path(); + int columnOrdinal = -1; ImmutableList.Builder columnMetadataBuilder = ImmutableList.builderWithExpectedSize(columns.size()); for (ColumnChunk columnChunk : columns) { + columnOrdinal++; validateParquet( (filePath == null && columnChunk.getFile_path() == null) || (filePath != null && filePath.equals(columnChunk.getFile_path())), dataSourceId, "all column chunks of the same row group must be in the same file"); + ColumnCryptoMetaData cryptoMetaData = columnChunk.getCrypto_metadata(); ColumnMetaData metaData = columnChunk.meta_data; - String[] path = metaData.path_in_schema.stream() - .map(value -> value.toLowerCase(Locale.ENGLISH)) - .toArray(String[]::new); - ColumnPath columnPath = ColumnPath.get(path); - PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); - ColumnChunkMetadata column = ColumnChunkMetadata.get( - columnPath, - primitiveType, - CompressionCodecName.fromParquet(metaData.codec), - convertEncodingStats(metaData.encoding_stats), - readEncodings(metaData.encodings), - readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), - metaData.data_page_offset, - metaData.dictionary_page_offset, - metaData.num_values, - metaData.total_compressed_size, - metaData.total_uncompressed_size); - column.setColumnIndexReference(toColumnIndexReference(columnChunk)); - column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); - column.setBloomFilterOffset(metaData.bloom_filter_offset); - columnMetadataBuilder.add(column); + ColumnPath columnPath = null; + boolean encryptedMetadata = false; + if (cryptoMetaData == null) { + columnPath = getPath(metaData); + if (fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile()) { + // mark this column as plaintext in 
encrypted file decryptor + fileDecryptor.get().setColumnCryptoMetadata(columnPath, false, false, (byte[]) null, columnOrdinal); + } + } + else { // Encrypted column + if (cryptoMetaData.isSetENCRYPTION_WITH_FOOTER_KEY()) { // Column encrypted with footer key + if (!encryptedFooterMode) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key in file with plaintext footer"); + } + if (null == metaData) { + throw new ParquetCryptoRuntimeException("ColumnMetaData not set in Encryption with Footer key"); + } + if (fileDecryptor.isEmpty()) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key: No keys available"); + } + columnPath = getPath(metaData); + fileDecryptor.get().setColumnCryptoMetadata(columnPath, true, true, (byte[]) null, columnOrdinal); + } + else { // Column encrypted with column key + encryptedMetadata = true; + } + } + try { + if (encryptedMetadata) { + // TODO: We decrypt data before filter projection. This could send unnecessary traffic to KMS. + // parquet-mr uses lazy decryption, but that requires changing ColumnChunkMetadata. We will improve it later. 
+ metaData = decryptMetadata(rowGroup, cryptoMetaData, columnChunk, fileDecryptor.get(), columnOrdinal); + columnPath = getPath(metaData); + } + PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); + ColumnChunkMetadata column = ColumnChunkMetadata.get( + columnPath, + primitiveType, + CompressionCodecName.fromParquet(metaData.codec), + convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + column.setColumnIndexReference(toColumnIndexReference(columnChunk)); + column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); + column.setBloomFilterOffset(metaData.bloom_filter_offset); + + if (rowGroup.isSetOrdinal()) { + column.setRowGroupOrdinal(rowGroup.getOrdinal()); + } + columnMetadataBuilder.add(column); + } + catch (KeyAccessDeniedException e) { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(columnPath, filePath); + columnMetadataBuilder.add(column); + } } - blocks.add(new BlockMetadata(rowGroup.getNum_rows(), columnMetadataBuilder.build())); + blocks.add(new BlockMetadata(rowGroup.getNum_rows(), rowGroup.getTotal_byte_size(), rowGroup.getOrdinal(), columnMetadataBuilder.build())); } } @@ -274,6 +371,25 @@ public static org.apache.parquet.column.statistics.Statistics> readStats(Optio return columnStatistics; } + /** + * If a column is encrypted and the user does not provide the correct key to decrypt it, that column is hidden from the current request. + * This method finds the first non-hidden column. + * + * @param block BlockMetaData + * @return index of the first non-hidden column, or null if every column is hidden. 
+ */ + public static Integer findFirstNonHiddenColumnId(BlockMetadata block) + { + List columns = block.columns(); + for (int i = 0; i < columns.size(); i++) { + if (!HiddenColumnChunkMetaData.isHiddenColumn(columns.get(i))) { + return i; + } + } + // all columns are hidden (encrypted but not accessible to current user) + return null; + } + private static boolean isStringType(PrimitiveType type) { if (type.getLogicalTypeAnnotation() == null) { @@ -373,4 +489,75 @@ private static void validateFileMetadata(ParquetDataSourceId dataSourceId, FileM Optional.ofNullable(fileMetaData.getKeyValueMetaData().get("writer.time.zone"))); writeValidation.validateColumns(dataSourceId, fileMetaData.getSchema()); } + + private static ColumnMetaData decryptMetadata(RowGroup rowGroup, ColumnCryptoMetaData cryptoMetaData, ColumnChunk columnChunk, InternalFileDecryptor fileDecryptor, int columnOrdinal) + { + EncryptionWithColumnKey columnKeyStruct = cryptoMetaData.getENCRYPTION_WITH_COLUMN_KEY(); + List pathList = columnKeyStruct.getPath_in_schema().stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()); + + byte[] columnKeyMetadata = columnKeyStruct.getKey_metadata(); + ColumnPath columnPath = ColumnPath.get(pathList.toArray(new String[pathList.size()])); + byte[] encryptedMetadataBuffer = columnChunk.getEncrypted_column_metadata(); + + // Decrypt the ColumnMetaData + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.setColumnCryptoMetadata(columnPath, true, false, columnKeyMetadata, columnOrdinal); + ByteArrayInputStream tempInputStream = new ByteArrayInputStream(encryptedMetadataBuffer); + byte[] columnMetaDataAAD = AesCipher.createModuleAAD(fileDecryptor.getFileAAD(), ModuleType.ColumnMetaData, rowGroup.ordinal, columnOrdinal, -1); + try { + return Util.readColumnMetaData(tempInputStream, columnDecryptionSetup.getMetaDataDecryptor(), columnMetaDataAAD); + } + catch (IOException e) { + throw new 
ParquetCryptoRuntimeException(columnPath + ". Failed to decrypt column metadata", e); + } + } + + /*public static ColumnChunkMetadata buildColumnChunkMetaData(Optional fileCreatedBy, ColumnMetaData metaData, ColumnPath columnPath, PrimitiveType type) + { + return ColumnChunkMetadata.get( + columnPath, + type, + CompressionCodecName.fromParquet(metaData.codec), + PARQUET_METADATA_CONVERTER.convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(fileCreatedBy, Optional.ofNullable(metaData.statistics), type), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + }*/ + + private static ColumnPath getPath(ColumnMetaData metaData) + { + String[] path = metaData.path_in_schema.stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)) + .toArray(String[]::new); + return ColumnPath.get(path); + } + + private static void verifyFooterIntegrity(InputStream metadataStream, InternalFileDecryptor fileDecryptor, int combinedFooterLength) + throws IOException + { + byte[] nonce = new byte[AesCipher.NONCE_LENGTH]; + metadataStream.read(nonce); + byte[] gcmTag = new byte[AesCipher.GCM_TAG_LENGTH]; + metadataStream.read(gcmTag); + + AesGcmEncryptor footerSigner = fileDecryptor.createSignedFooterEncryptor(); + int footerSignatureLength = AesCipher.NONCE_LENGTH + AesCipher.GCM_TAG_LENGTH; + byte[] serializedFooter = new byte[combinedFooterLength - footerSignatureLength]; + + //InputStream doesn't implement reset(). 
Here is to workaround + ((BasicSliceInput) metadataStream).setPosition(0); + metadataStream.read(serializedFooter, 0, serializedFooter.length); + + byte[] signedFooterAAD = AesCipher.createFooterAAD(fileDecryptor.getFileAAD()); + byte[] encryptedFooterBytes = footerSigner.encrypt(false, serializedFooter, nonce, signedFooterAAD); + byte[] calculatedTag = new byte[AesCipher.GCM_TAG_LENGTH]; + System.arraycopy(encryptedFooterBytes, encryptedFooterBytes.length - AesCipher.GCM_TAG_LENGTH, calculatedTag, 0, AesCipher.GCM_TAG_LENGTH); + if (!Arrays.equals(gcmTag, calculatedTag)) { + throw new TagVerificationException("Signature mismatch in plaintext footer"); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java index d8ec35c52fbe..799d4b111654 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java @@ -16,17 +16,24 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; +import io.airlift.slice.Slice; import io.trino.parquet.DataPage; import io.trino.parquet.DataPageV1; import io.trino.parquet.DataPageV2; import io.trino.parquet.DictionaryPage; import io.trino.parquet.Page; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.ModuleCipherFactory; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.CompressionCodec; +import 
org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -35,6 +42,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.parquet.ParquetCompressionUtils.decompress; import static io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages; import static java.util.Objects.requireNonNull; @@ -49,6 +57,10 @@ public final class PageReader private boolean dictionaryAlreadyRead; private int dataPageReadCount; + private int pageIndex; + private final BlockCipher.Decryptor blockDecryptor; + private byte[] dataPageAAD; + private byte[] dictionaryPageAAD; public static PageReader createPageReader( ParquetDataSourceId dataSourceId, @@ -56,7 +68,8 @@ public static PageReader createPageReader( ColumnChunkMetadata metadata, ColumnDescriptor columnDescriptor, @Nullable OffsetIndex offsetIndex, - Optional fileCreatedBy) + Optional fileCreatedBy, + Optional fileDecryptor) { // Parquet schema may specify a column definition as OPTIONAL even though there are no nulls in the actual data. 
// Row-group column statistics can be used to identify such cases and switch to faster non-nullable read @@ -64,20 +77,36 @@ public static PageReader createPageReader( Statistics> columnStatistics = metadata.getStatistics(); boolean hasNoNulls = columnStatistics != null && columnStatistics.getNumNulls() == 0; boolean hasOnlyDictionaryEncodedPages = isOnlyDictionaryEncodingPages(metadata); + byte[] fileAad = null; + BlockCipher.Decryptor dataDecryptor = null; + int columnOrdinal = -1; + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + fileAad = fileDecryptor.get().getFileAAD(); + dataDecryptor = columnDecryptionSetup.getDataDecryptor(); + columnOrdinal = columnDecryptionSetup.getOrdinal(); + } ParquetColumnChunkIterator compressedPages = new ParquetColumnChunkIterator( dataSourceId, fileCreatedBy, columnDescriptor, metadata, columnChunk, - offsetIndex); + offsetIndex, + fileDecryptor, + columnOrdinal); return new PageReader( dataSourceId, metadata.getCodec().getParquetCompressionCodec(), compressedPages, hasOnlyDictionaryEncodedPages, - hasNoNulls); + hasNoNulls, + dataDecryptor, + fileAad, + metadata.getRowGroupOrdinal(), + columnOrdinal); } @VisibleForTesting @@ -86,13 +115,22 @@ public PageReader( CompressionCodec codec, Iterator extends Page> compressedPages, boolean hasOnlyDictionaryEncodedPages, - boolean hasNoNulls) + boolean hasNoNulls, + BlockCipher.Decryptor blockDecryptor, + byte[] fileAAD, + int rowGroupOrdinal, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.codec = codec; this.compressedPages = Iterators.peekingIterator(compressedPages); this.hasOnlyDictionaryEncodedPages = hasOnlyDictionaryEncodedPages; this.hasNoNulls = hasNoNulls; + this.blockDecryptor = blockDecryptor; + if (null != blockDecryptor) { + dataPageAAD = 
AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DataPage, rowGroupOrdinal, columnOrdinal, 0); + dictionaryPageAAD = AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DictionaryPage, rowGroupOrdinal, columnOrdinal, -1); + } } public boolean hasNoNulls() @@ -114,18 +152,23 @@ public DataPage readPage() checkState(compressedPage instanceof DataPage, "Found page %s instead of a DataPage", compressedPage); dataPageReadCount++; try { + if (null != blockDecryptor) { + AesCipher.quickUpdatePageAAD(dataPageAAD, ((DataPage) compressedPage).getPageIndex()); + } + Slice slice = decryptSliceIfNeeded(compressedPage.getSlice(), dataPageAAD); if (compressedPage instanceof DataPageV1 dataPageV1) { if (!arePagesCompressed()) { return dataPageV1; } return new DataPageV1( - decompress(dataSourceId, codec, dataPageV1.getSlice(), dataPageV1.getUncompressedSize()), + decompress(dataSourceId, codec, slice, dataPageV1.getUncompressedSize()), dataPageV1.getValueCount(), dataPageV1.getUncompressedSize(), dataPageV1.getFirstRowIndex(), dataPageV1.getRepetitionLevelEncoding(), dataPageV1.getDefinitionLevelEncoding(), - dataPageV1.getValueEncoding()); + dataPageV1.getValueEncoding(), + dataPageV1.getPageIndex()); } DataPageV2 dataPageV2 = (DataPageV2) compressedPage; if (!dataPageV2.isCompressed()) { @@ -141,11 +184,12 @@ public DataPage readPage() dataPageV2.getRepetitionLevels(), dataPageV2.getDefinitionLevels(), dataPageV2.getDataEncoding(), - decompress(dataSourceId, codec, dataPageV2.getSlice(), uncompressedSize), + decompress(dataSourceId, codec, slice, uncompressedSize), dataPageV2.getUncompressedSize(), dataPageV2.getFirstRowIndex(), dataPageV2.getStatistics(), - false); + false, + dataPageV2.getPageIndex()); } catch (IOException e) { throw new RuntimeException("Could not decompress page", e); @@ -162,8 +206,9 @@ public DictionaryPage readDictionaryPage() } try { DictionaryPage compressedDictionaryPage = (DictionaryPage) compressedPages.next(); + Slice 
slice = decryptSliceIfNeeded(compressedDictionaryPage.getSlice(), dictionaryPageAAD); return new DictionaryPage( - decompress(dataSourceId, codec, compressedDictionaryPage.getSlice(), compressedDictionaryPage.getUncompressedSize()), + decompress(dataSourceId, codec, slice, compressedDictionaryPage.getUncompressedSize()), compressedDictionaryPage.getDictionarySize(), compressedDictionaryPage.getEncoding()); } @@ -199,4 +244,14 @@ private void verifyDictionaryPageRead() { checkArgument(dictionaryAlreadyRead, "Dictionary has to be read first"); } + + private Slice decryptSliceIfNeeded(Slice slice, byte[] aad) + throws IOException + { + if (blockDecryptor == null) { + return slice; + } + byte[] plainText = blockDecryptor.decrypt(slice.getBytes(), aad); + return wrappedBuffer(plainText); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java index 235c1b2d3d76..720d5f16151f 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java @@ -19,15 +19,21 @@ import io.trino.parquet.Page; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.ModuleCipherFactory; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.DataPageHeader; import org.apache.parquet.format.DataPageHeaderV2; import org.apache.parquet.format.DictionaryPageHeader; import 
org.apache.parquet.format.PageHeader; import org.apache.parquet.format.Util; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -51,6 +57,9 @@ public final class ParquetColumnChunkIterator private long valueCount; private int dataPageCount; + private Optional fileDecryptor; + private int columnOrdinal; + private boolean dictionaryWasRead; public ParquetColumnChunkIterator( ParquetDataSourceId dataSourceId, @@ -58,7 +67,9 @@ public ParquetColumnChunkIterator( ColumnDescriptor descriptor, ColumnChunkMetadata metadata, ChunkedInputStream input, - @Nullable OffsetIndex offsetIndex) + @Nullable OffsetIndex offsetIndex, + Optional fileDecryptor, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -66,6 +77,8 @@ public ParquetColumnChunkIterator( this.metadata = requireNonNull(metadata, "metadata is null"); this.input = requireNonNull(input, "input is null"); this.offsetIndex = offsetIndex; + this.fileDecryptor = fileDecryptor; + this.columnOrdinal = columnOrdinal; } @Override @@ -79,8 +92,32 @@ public Page next() { checkState(hasNext(), "No more data left to read in column (%s), metadata (%s), valueCount %s, dataPageCount %s", descriptor, metadata, valueCount, dataPageCount); + byte[] dataPageHeaderAAD = null; + BlockCipher.Decryptor headerBlockDecryptor = null; + InternalColumnDecryptionSetup columnDecryptionSetup = null; + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(descriptor.getPath()); + columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + headerBlockDecryptor = columnDecryptionSetup.getMetaDataDecryptor(); + if (null != headerBlockDecryptor) { + dataPageHeaderAAD = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), + ModuleCipherFactory.ModuleType.DataPageHeader, 
metadata.getRowGroupOrdinal(), columnOrdinal, dataPageCount); + } + } try { - PageHeader pageHeader = readPageHeader(); + byte[] pageHeaderAAD = dataPageHeaderAAD; + if (null != headerBlockDecryptor) { + // Important: this verifies file integrity (makes sure dictionary page had not been removed) + if (!(dictionaryWasRead || !metadata.hasDictionaryPage())) { + pageHeaderAAD = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), + ModuleCipherFactory.ModuleType.DictionaryPageHeader, metadata.getRowGroupOrdinal(), + columnOrdinal, -1); + } + else { + AesCipher.quickUpdatePageAAD(dataPageHeaderAAD, dataPageCount); + } + } + PageHeader pageHeader = readPageHeader(headerBlockDecryptor, pageHeaderAAD); int uncompressedPageSize = pageHeader.getUncompressed_page_size(); int compressedPageSize = pageHeader.getCompressed_page_size(); Page result = null; @@ -90,13 +127,14 @@ public Page next() throw new ParquetCorruptionException(dataSourceId, "Column (%s) has a dictionary page after the first position in column chunk", descriptor); } result = readDictionaryPage(pageHeader, pageHeader.getUncompressed_page_size(), pageHeader.getCompressed_page_size()); + dictionaryWasRead = true; break; case DATA_PAGE: - result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; case DATA_PAGE_V2: - result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; default: @@ -110,10 +148,10 @@ public Page next() } } - private PageHeader readPageHeader() + private PageHeader readPageHeader(BlockCipher.Decryptor headerBlockDecryptor, byte[] 
pageHeaderAAD) throws IOException { - return Util.readPageHeader(input); + return Util.readPageHeader(input, headerBlockDecryptor, pageHeaderAAD); } private boolean hasMorePages(long valuesCountReadSoFar, int dataPageCountReadSoFar) @@ -139,7 +177,8 @@ private DataPageV1 readDataPageV1( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeader dataHeaderV1 = pageHeader.getData_page_header(); @@ -151,14 +190,16 @@ private DataPageV1 readDataPageV1( firstRowIndex, getParquetEncoding(Encoding.valueOf(dataHeaderV1.getRepetition_level_encoding().name())), getParquetEncoding(Encoding.valueOf(dataHeaderV1.getDefinition_level_encoding().name())), - getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name()))); + getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name())), + pageIndex); } private DataPageV2 readDataPageV2( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeaderV2 dataHeaderV2 = pageHeader.getData_page_header_v2(); @@ -178,7 +219,8 @@ private DataPageV2 readDataPageV2( fileCreatedBy, Optional.ofNullable(dataHeaderV2.getStatistics()), descriptor.getPrimitiveType()), - dataHeaderV2.isIs_compressed()); + dataHeaderV2.isIs_compressed(), + pageIndex); } private static OptionalLong getFirstRowIndex(int pageIndex, OffsetIndex offsetIndex) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 0ad000ccd420..128375e5a32e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -30,6 +30,9 @@ import io.trino.parquet.ParquetReaderOptions; import 
io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.PrimitiveField; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; import io.trino.parquet.predicate.TupleDomainParquetPredicate; @@ -129,6 +132,7 @@ public class ParquetReader private final Map> codecMetrics; private long columnIndexRowsFiltered = -1; + private final Optional fileDecryptor; public ParquetReader( Optional fileCreatedBy, @@ -140,7 +144,8 @@ public ParquetReader( ParquetReaderOptions options, Function exceptionTransform, Optional parquetPredicate, - Optional writeValidation) + Optional writeValidation, + Optional fileDecryptor) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -156,6 +161,7 @@ public ParquetReader( this.maxBatchSize = options.getMaxReadBlockRowCount(); this.columnReaders = new HashMap<>(); this.maxBytesPerCell = new HashMap<>(); + this.fileDecryptor = fileDecryptor; this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); validateWrite( @@ -264,7 +270,7 @@ public long lastBatchStartRow() return firstRowIndexInGroup + nextRowInGroup - batchSize; } - private int nextBatch() + public int nextBatch() throws IOException { if (nextRowInGroup >= currentGroupRowCount && !advanceToNextRowGroup()) { @@ -457,9 +463,16 @@ private ColumnChunk readPrimitive(PrimitiveField field) offsetIndex = getFilteredOffsetIndex(rowRanges, currentRowGroup, currentBlockMetadata.getRowCount(), metadata.getPath()); } ChunkedInputStream columnChunkInputStream = chunkReaders.get(new ChunkKey(fieldId, currentRowGroup)); - columnReader.setPageReader( - createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy), - Optional.ofNullable(rowRanges)); + if 
(isEncryptedColumn(fileDecryptor, columnDescriptor)) { + columnReader.setPageReader( + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptor), + Optional.ofNullable(rowRanges)); + } + else { + columnReader.setPageReader( + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptor), + Optional.ofNullable(rowRanges)); + } } ColumnChunk columnChunk = columnReader.readPrimitive(); @@ -491,6 +504,19 @@ public Metrics getMetrics() return new Metrics(metrics.buildOrThrow()); } + private ColumnChunkMetadata getColumnChunkMetaData(BlockMetadata blockMetaData, ColumnDescriptor columnDescriptor) + throws IOException + { + for (ColumnChunkMetadata metadata : blockMetaData.columns()) { + if (!HiddenColumnChunkMetaData.isHiddenColumn(metadata)) { + if (metadata.getPath().equals(ColumnPath.get(columnDescriptor.getPath()))) { + return metadata; + } + } + } + throw new ParquetCorruptionException(dataSource.getId(), "Metadata is missing for column: %s", columnDescriptor); + } + private void initializeColumnReaders() { for (PrimitiveField field : primitiveFields) { @@ -612,4 +638,10 @@ private void validateWrite(java.util.function.Predicate throw new ParquetCorruptionException(dataSource.getId(), "Write validation failed: " + messageFormat, args); } } + + private boolean isEncryptedColumn(Optional fileDecryptor, ColumnDescriptor columnDescriptor) + { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + return fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile() && fileDecryptor.get().getColumnSetup(columnPath).isEncrypted(); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 651d86040ef5..9eb40a5665e4 100644 --- 
a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -237,7 +237,7 @@ public void validate(ParquetDataSource input) checkState(validationBuilder.isPresent(), "validation is not enabled"); ParquetWriteValidation writeValidation = validationBuilder.get().build(); try { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation)); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation), Optional.empty()); try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { // fully load the page @@ -293,7 +293,8 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada return new RuntimeException(exception); }, Optional.empty(), - Optional.of(writeValidation)); + Optional.of(writeValidation), + Optional.empty()); } private void recordValidation(Consumer task) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java index 9f7918115838..e6cdd9825e77 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java @@ -225,7 +225,7 @@ public void setup() testData.getColumnNames(), testData.getPages()), new ParquetReaderOptions()); - parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); columnNames = columns.stream() .map(TpchColumn::getColumnName) .collect(toImmutableList()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java 
b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index febdaccf617b..59280c6de102 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -164,6 +164,7 @@ public static ParquetReader createParquetReader( return new RuntimeException(exception); }, Optional.of(parquetPredicate), + Optional.empty(), Optional.empty()); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java index fc47c42d8d82..448ef7dc26a8 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java @@ -105,7 +105,7 @@ public int read() throws IOException { ColumnReader columnReader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false, null, null, -1, -1); columnReader.setPageReader(pageReader, Optional.empty()); int rowsRead = 0; while (rowsRead < dataPositions) { @@ -133,7 +133,8 @@ private DataPage createDataPage(ValuesWriter writer, int valuesCount) OptionalLong.empty(), RLE, RLE, - getParquetEncoding(writer.getEncoding())); + getParquetEncoding(writer.getEncoding()), + 0); } protected static void run(Class> clazz) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java index 6a3fccb1e281..37dde42f5e57 100644 --- 
a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java @@ -564,7 +564,11 @@ else if (dictionaryEncoding == DictionaryEncoding.MIXED) { UNCOMPRESSED, inputPages.iterator(), dictionaryEncoding == DictionaryEncoding.ALL || (dictionaryEncoding == DictionaryEncoding.MIXED && testingPages.size() == 1), - false); + false, + null, + null, + -1, + -1); } private static List createDataPages(List testingPages, ValuesWriter encoder, int maxDef, boolean required) @@ -599,7 +603,8 @@ private static DataPage createDataPage(TestingPage testingPage, ValuesWriter enc valueCount * 4, OptionalLong.of(testingPage.pageRowRange().start()), null, - false); + false, + 0); encoder.reset(); return dataPage; } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java index 445b61268c33..8b8fe067c88a 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java @@ -660,7 +660,8 @@ protected static DataPage createDataPage( OptionalLong.empty(), getParquetEncoding(repetitionWriter.getEncoding()), getParquetEncoding(definitionWriter.getEncoding()), - encoding); + encoding, + 0); } return new DataPageV2( valueCount, @@ -673,7 +674,8 @@ protected static DataPage createDataPage( definitionBytes.length + repetitionBytes.length + valueBytes.length, OptionalLong.empty(), null, - false); + false, + 0); } protected static PageReader getPageReaderMock(List dataPages, @Nullable DictionaryPage dictionaryPage) @@ -699,7 +701,7 @@ protected static PageReader getPageReaderMock(List dataPages, @Nullabl return ((DataPageV2) page).getDataEncoding(); }) .allMatch(encoding -> encoding == PLAIN_DICTIONARY || encoding == 
RLE_DICTIONARY), - hasNoNulls); + hasNoNulls, null, null, -1, -1); } private DataPage createDataPage(DataPageVersion version, ParquetEncoding encoding, ValuesWriter writer, int valueCount) @@ -713,7 +715,7 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin { Slice slice = Slices.wrappedBuffer(writer.getBytes().toByteArray()); if (version == V1) { - return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding); + return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding, 0); } return new DataPageV2( valueCount, @@ -726,7 +728,8 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin slice.length(), firstRowIndex, null, - false); + false, + 0); } private static ValuesWriter getLevelsWriter(int maxLevel, int valueCount) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java new file mode 100644 index 000000000000..ac6981666b57 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.parquet.reader;
+
+import io.trino.parquet.crypto.ColumnEncryptionProperties;
+import io.trino.parquet.crypto.DecryptionKeyRetriever;
+import io.trino.parquet.crypto.FileDecryptionProperties;
+import io.trino.parquet.crypto.FileEncryptionProperties;
+import io.trino.parquet.crypto.ParquetCipher;
+import org.apache.parquet.hadoop.metadata.ColumnPath;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Test helper that assembles Parquet encryption and decryption properties from fixed, hard-coded keys.
+ */
+public class EncDecPropertiesHelper
+{
+    private static final byte[] FOOTER_KEY = {0x01, 0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
+            0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10};
+    private static final byte[] FOOTER_KEY_METADATA = "footkey".getBytes(StandardCharsets.UTF_8);
+    private static final byte[] COL_KEY = {0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+            0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11};
+    private static final byte[] COL_KEY_METADATA = "col".getBytes(StandardCharsets.UTF_8);
+
+    private EncDecPropertiesHelper()
+    {
+    }
+
+    /**
+     * Builds decryption properties that allow plaintext files and resolve the footer key
+     * ("footkey") and the column key ("col") through an in-memory retriever.
+     */
+    public static FileDecryptionProperties getFileDecryptionProperties()
+            throws IOException
+    {
+        DecryptionKeyRetrieverMock retriever = new DecryptionKeyRetrieverMock()
+                .putKey("footkey", FOOTER_KEY)
+                .putKey("col", COL_KEY);
+        return FileDecryptionProperties.builder()
+                .withPlaintextFilesAllowed()
+                .withKeyRetriever(retriever)
+                .build();
+    }
+
+    /**
+     * Builds encryption properties that encrypt each listed column with the shared column key.
+     *
+     * @param encryptColumns dot-separated paths of the columns to encrypt
+     * @param cipher algorithm handed to the properties builder
+     * @param encryptFooter when false, the builder is asked for a plaintext footer
+     * @return the encryption properties, or {@code null} when {@code encryptColumns} is empty
+     */
+    public static FileEncryptionProperties getFileEncryptionProperties(List<String> encryptColumns, ParquetCipher cipher, Boolean encryptFooter)
+    {
+        if (encryptColumns.isEmpty()) {
+            return null;
+        }
+
+        Map<ColumnPath, ColumnEncryptionProperties> encryptedColumns = new HashMap<>();
+        for (String columnName : encryptColumns) {
+            ColumnPath path = ColumnPath.fromDotString(columnName);
+            encryptedColumns.put(
+                    path,
+                    ColumnEncryptionProperties.builder(path)
+                            .withKey(COL_KEY)
+                            .withKeyMetaData(COL_KEY_METADATA)
+                            .build());
+        }
+
+        FileEncryptionProperties.Builder builder = FileEncryptionProperties.builder(FOOTER_KEY)
+                .withFooterKeyMetadata(FOOTER_KEY_METADATA)
+                .withAlgorithm(cipher)
+                .withEncryptedColumns(encryptedColumns);
+
+        if (!encryptFooter) {
+            builder.withPlaintextFooter();
+        }
+        return builder.build();
+    }
+
+    /**
+     * Key retriever backed by a map; the key metadata bytes are decoded as UTF-8 and used as the map key.
+     */
+    private static class DecryptionKeyRetrieverMock
+            implements DecryptionKeyRetriever
+    {
+        private final Map<String, byte[]> keyMap = new HashMap<>();
+
+        public DecryptionKeyRetrieverMock putKey(String keyId, byte[] keyBytes)
+        {
+            keyMap.put(keyId, keyBytes);
+            return this;
+        }
+
+        @Override
+        public byte[] getKey(byte[] keyMetaData)
+        {
+            // Yields null when no key was registered under the decoded id
+            return keyMap.get(new String(keyMetaData, StandardCharsets.UTF_8));
+        }
+    }
+}
 diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java new file mode 100644 index 000000000000..d7677525ef13 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.parquet.reader;
+
+import org.apache.parquet.example.data.simple.SimpleGroup;
+
+/**
+ * Pairs the name of an encryption test file with the records it holds.
+ */
+public class EncryptionTestFile
+{
+    private final String fileName;
+    private final SimpleGroup[] fileContent;
+
+    /**
+     * @param fileName name of the test file
+     * @param fileContent records associated with the file; the array is stored as given, not copied
+     */
+    public EncryptionTestFile(String fileName, SimpleGroup[] fileContent)
+    {
+        this.fileName = fileName;
+        this.fileContent = fileContent;
+    }
+
+    /** Returns the file name passed to the constructor. */
+    public String getFileName()
+    {
+        return fileName;
+    }
+
+    /** Returns the same array instance that was passed to the constructor. */
+    public SimpleGroup[] getFileContent()
+    {
+        return fileContent;
+    }
+}
 diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java new file mode 100644 index 000000000000..dd46ccb689b0 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.parquet.reader;
+
+import io.airlift.slice.Slice;
+import io.airlift.slice.Slices;
+import org.apache.hadoop.fs.FSDataInputStream;
+
+import java.io.IOException;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.lang.Math.toIntExact;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Tail of an input stream together with the file size discovered while reading it.
+ * The caller-supplied size may overstate the real size by up to
+ * {@link #MAX_SUPPORTED_PADDING_BYTES} bytes; the real end is located by reading until EOF.
+ */
+public final class MockInputStreamTail
+{
+    public static final int MAX_SUPPORTED_PADDING_BYTES = 64;
+    private static final int MAXIMUM_READ_LENGTH = Integer.MAX_VALUE - (MAX_SUPPORTED_PADDING_BYTES + 1);
+
+    private final Slice tailSlice;
+    private final long fileSize;
+
+    private MockInputStreamTail(long fileSize, Slice tailSlice)
+    {
+        requireNonNull(tailSlice, "tailSlice is null");
+        checkArgument(fileSize >= 0, "fileSize is negative: %s", fileSize);
+        checkArgument(tailSlice.length() <= fileSize, "length (%s) is greater than fileSize (%s)", tailSlice.length(), fileSize);
+        this.tailSlice = tailSlice;
+        this.fileSize = fileSize;
+    }
+
+    /**
+     * Reads up to {@code length} trailing bytes of the stream and restores the stream position afterwards.
+     *
+     * @param path used only in the error message
+     * @param paddedFileSize reported file size, possibly larger than the real one
+     * @param inputStream stream to read from; its position is restored before returning
+     * @param length maximum number of tail bytes wanted
+     * @throws IOException if reading fails or the stream continues past {@code paddedFileSize}
+     */
+    public static MockInputStreamTail readTail(String path, long paddedFileSize, FSDataInputStream inputStream, int length)
+            throws IOException
+    {
+        checkArgument(length >= 0, "length is negative: %s", length);
+        checkArgument(length <= MAXIMUM_READ_LENGTH, "length (%s) exceeds maximum (%s)", length, MAXIMUM_READ_LENGTH);
+
+        long expectedBytes = min(paddedFileSize, (length + MAX_SUPPORTED_PADDING_BYTES));
+        long readStart = paddedFileSize - expectedBytes;
+        // One spare byte: if it gets filled, the stream extends past the reported size
+        byte[] buffer = new byte[toIntExact(expectedBytes + 1)];
+        int filled = 0;
+
+        long originalPosition = inputStream.getPos();
+        try {
+            inputStream.seek(readStart);
+            int count;
+            while (filled < buffer.length && (count = inputStream.read(buffer, filled, buffer.length - filled)) >= 0) {
+                filled += count;
+            }
+        }
+        finally {
+            inputStream.seek(originalPosition);
+        }
+
+        if (filled > expectedBytes) {
+            throw rejectInvalidFileSize(path, paddedFileSize);
+        }
+
+        // Keep only the last `length` bytes of what was read
+        int tailOffset = max(0, filled - length);
+        int tailLength = min(filled, length);
+        return new MockInputStreamTail(readStart + filled, Slices.wrappedBuffer(buffer, tailOffset, tailLength));
+    }
+
+    /**
+     * Finds the real file size by reading single bytes, starting at most
+     * {@link #MAX_SUPPORTED_PADDING_BYTES} before {@code paddedFileSize}, until EOF.
+     * The stream position is restored afterwards.
+     *
+     * @throws IOException if reading fails or EOF is not reached by {@code paddedFileSize}
+     */
+    public static long readTailForFileSize(String path, long paddedFileSize, FSDataInputStream inputStream)
+            throws IOException
+    {
+        long scanStart = max(paddedFileSize - MAX_SUPPORTED_PADDING_BYTES, 0);
+        long scanLimit = paddedFileSize + 1;
+        long originalPosition = inputStream.getPos();
+        try {
+            inputStream.seek(scanStart);
+            for (long offset = scanStart; offset < scanLimit; offset++) {
+                if (inputStream.read() < 0) {
+                    return offset;
+                }
+            }
+            throw rejectInvalidFileSize(path, paddedFileSize);
+        }
+        finally {
+            inputStream.seek(originalPosition);
+        }
+    }
+
+    // Always throws; the return type only lets call sites write `throw rejectInvalidFileSize(...)`
+    private static IOException rejectInvalidFileSize(String path, long reportedSize)
+            throws IOException
+    {
+        throw new IOException(format("Incorrect file size (%s) for file (end of stream not reached): %s", reportedSize, path));
+    }
+
+    /** Returns the position at which the end of the stream was found. */
+    public long getFileSize()
+    {
+        return fileSize;
+    }
+
+    /** Returns the trailing bytes that were read. */
+    public Slice getTailSlice()
+    {
+        return tailSlice;
+    }
+}
 diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java new file mode 100644 index 000000000000..2652e2da3301 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java @@ -0,0 +1,335 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.reader;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ListMultimap;
+import io.airlift.slice.Slice;
+import io.airlift.slice.Slices;
+import io.airlift.units.DataSize;
+import io.trino.memory.context.AggregatedMemoryContext;
+import io.trino.parquet.ChunkReader;
+import io.trino.parquet.DiskRange;
+import io.trino.parquet.ParquetDataSource;
+import io.trino.parquet.ParquetDataSourceId;
+import io.trino.parquet.ParquetReaderOptions;
+import org.apache.hadoop.fs.FSDataInputStream;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.lang.Math.toIntExact;
+import static java.util.Comparator.comparingLong;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * {@link ParquetDataSource} for tests that reads from a Hadoop {@link FSDataInputStream}.
+ * It accumulates the number of bytes requested and the time spent reading.
+ */
+public class MockParquetDataSource
+        implements ParquetDataSource
+{
+    private final ParquetDataSourceId id;
+    private final long estimatedSize;
+    private final FSDataInputStream inputStream;
+    private long readTimeNanos;
+    private long readBytes;
+    private final ParquetReaderOptions options;
+
+    /**
+     * @param id identifier reported by {@link #getId()} and used in error messages
+     * @param estimatedSize file size passed to the tail reader, which tolerates some padding
+     * @param inputStream stream to read from; closed by {@link #close()}
+     * @param options supplies the maximum buffer size and merge distance used when planning reads
+     */
+    public MockParquetDataSource(
+            ParquetDataSourceId id,
+            long estimatedSize,
+            FSDataInputStream inputStream,
+            ParquetReaderOptions options)
+    {
+        this.id = requireNonNull(id, "id is null");
+        this.estimatedSize = estimatedSize;
+        this.inputStream = requireNonNull(inputStream, "inputStream is null");
+        this.options = requireNonNull(options, "options is null");
+    }
+
+    @Override
+    public ParquetDataSourceId getId()
+    {
+        return id;
+    }
+
+    @Override
+    public final long getReadBytes()
+    {
+        return readBytes;
+    }
+
+    @Override
+    public long getReadTimeNanos()
+    {
+        return readTimeNanos;
+    }
+
+    @Override
+    public final long getEstimatedSize()
+    {
+        return estimatedSize;
+    }
+
+    @Override
+    public void close()
+            throws IOException
+    {
+        inputStream.close();
+    }
+
+    /**
+     * Reads up to {@code length} bytes from the end of the stream.
+     *
+     * @throws RuntimeException wrapping the {@link IOException} if the tail cannot be read
+     */
+    @Override
+    public Slice readTail(int length)
+    {
+        long start = System.nanoTime();
+        Slice tailSlice;
+        try {
+            // Handle potentially imprecise file lengths by reading the footer
+            MockInputStreamTail fileTail = MockInputStreamTail.readTail(getId().toString(), getEstimatedSize(), inputStream, length);
+            tailSlice = fileTail.getTailSlice();
+        }
+        catch (IOException e) {
+            // Fill in the placeholders and keep the cause so failures can be diagnosed
+            throw new RuntimeException(String.format("Error reading tail from %s with length %s", id, length), e);
+        }
+        long currentReadTimeNanos = System.nanoTime() - start;
+
+        readTimeNanos += currentReadTimeNanos;
+        readBytes += tailSlice.length();
+        return tailSlice;
+    }
+
+    /**
+     * Reads exactly {@code length} bytes starting at {@code position}.
+     */
+    @Override
+    public final Slice readFully(long position, int length)
+    {
+        byte[] buffer = new byte[length];
+        readFully(position, buffer, 0, length);
+        return Slices.wrappedBuffer(buffer);
+    }
+
+    /**
+     * Plans the reads for the given ranges and exposes each key's chunks as a {@link ChunkedInputStream}.
+     *
+     * @return an empty map when {@code diskRanges} is empty
+     */
+    @Override
+    public final <K> Map<K, ChunkedInputStream> planRead(ListMultimap<K, DiskRange> diskRanges, AggregatedMemoryContext memoryContext)
+    {
+        requireNonNull(diskRanges, "diskRanges is null");
+
+        if (diskRanges.isEmpty()) {
+            return ImmutableMap.of();
+        }
+
+        return planChunksRead(diskRanges, memoryContext).asMap()
+                .entrySet().stream()
+                .collect(toImmutableMap(Map.Entry::getKey, entry -> new ChunkedInputStream(entry.getValue())));
+    }
+
+    /**
+     * Creates one lazy {@link ChunkReader} per requested range. Ranges no longer than the
+     * configured maximum buffer size may share a single merged read; longer ranges are read
+     * individually. {@code memoryContext} is not used by this mock.
+     */
+    public <K> ListMultimap<K, ChunkReader> planChunksRead(ListMultimap<K, DiskRange> diskRanges, AggregatedMemoryContext memoryContext)
+    {
+        requireNonNull(diskRanges, "diskRanges is null");
+
+        if (diskRanges.isEmpty()) {
+            return ImmutableListMultimap.of();
+        }
+
+        //
+        // Note: this code does not use the stream APIs to avoid any extra object allocation
+        //
+
+        // split disk ranges into "big" and "small"
+        long maxBufferSizeBytes = options.getMaxBufferSize().toBytes();
+        ImmutableListMultimap.Builder<K, DiskRange> smallRangesBuilder = ImmutableListMultimap.builder();
+        ImmutableListMultimap.Builder<K, DiskRange> largeRangesBuilder = ImmutableListMultimap.builder();
+        for (Map.Entry<K, DiskRange> entry : diskRanges.entries()) {
+            if (entry.getValue().getLength() <= maxBufferSizeBytes) {
+                smallRangesBuilder.put(entry);
+            }
+            else {
+                largeRangesBuilder.put(entry);
+            }
+        }
+        ListMultimap<K, DiskRange> smallRanges = smallRangesBuilder.build();
+        ListMultimap<K, DiskRange> largeRanges = largeRangesBuilder.build();
+
+        // read ranges
+        ImmutableListMultimap.Builder<K, ChunkReader> slices = ImmutableListMultimap.builder();
+        slices.putAll(readSmallDiskRanges(smallRanges));
+        slices.putAll(readLargeDiskRanges(largeRanges));
+
+        return slices.build();
+    }
+
+    // Positional read into `buffer`; the requested length is counted before the read is attempted
+    private void readFully(long position, byte[] buffer, int bufferOffset, int bufferLength)
+    {
+        readBytes += bufferLength;
+
+        long start = System.nanoTime();
+        try {
+            inputStream.readFully(position, buffer, bufferOffset, bufferLength);
+        }
+        catch (Exception e) {
+            throw new RuntimeException("Error reading from " + id + " at position " + position, e);
+        }
+        long currentReadTimeNanos = System.nanoTime() - start;
+
+        readTimeNanos += currentReadTimeNanos;
+    }
+
+    // Merges nearby ranges and hands every original range a view into the shared merged read
+    private <K> ListMultimap<K, ChunkReader> readSmallDiskRanges(ListMultimap<K, DiskRange> diskRanges)
+    {
+        if (diskRanges.isEmpty()) {
+            return ImmutableListMultimap.of();
+        }
+
+        Iterable<DiskRange> mergedRanges = mergeAdjacentDiskRanges(diskRanges.values(), options.getMaxMergeDistance(), options.getMaxBufferSize());
+
+        ImmutableListMultimap.Builder<K, ChunkReader> slices = ImmutableListMultimap.builder();
+        for (DiskRange mergedRange : mergedRanges) {
+            ReferenceCountedReader mergedRangeLoader = new ReferenceCountedReader(mergedRange);
+
+            for (Map.Entry<K, DiskRange> diskRangeEntry : diskRanges.entries()) {
+                DiskRange diskRange = diskRangeEntry.getValue();
+                if (mergedRange.contains(diskRange)) {
+                    mergedRangeLoader.addReference();
+
+                    slices.put(diskRangeEntry.getKey(), new ChunkReader()
+                    {
+                        @Override
+                        public Slice read()
+                        {
+                            int offset = toIntExact(diskRange.getOffset() - mergedRange.getOffset());
+                            // toIntExact fails loudly instead of silently truncating an oversized length
+                            return mergedRangeLoader.read().slice(offset, toIntExact(diskRange.getLength()));
+                        }
+
+                        @Override
+                        public void free()
+                        {
+                            mergedRangeLoader.free();
+                        }
+
+                        @Override
+                        public long getDiskOffset()
+                        {
+                            return diskRange.getOffset();
+                        }
+                    });
+                }
+            }
+
+            // Drop the reference taken at construction; the chunk readers hold the remaining ones
+            mergedRangeLoader.free();
+        }
+
+        ListMultimap<K, ChunkReader> sliceStreams = slices.build();
+        verify(sliceStreams.keySet().equals(diskRanges.keySet()));
+        return sliceStreams;
+    }
+
+    // Large ranges are never merged: each gets its own lazily-loading reader
+    private <K> ListMultimap<K, ChunkReader> readLargeDiskRanges(ListMultimap<K, DiskRange> diskRanges)
+    {
+        if (diskRanges.isEmpty()) {
+            return ImmutableListMultimap.of();
+        }
+
+        ImmutableListMultimap.Builder<K, ChunkReader> slices = ImmutableListMultimap.builder();
+        for (Map.Entry<K, DiskRange> entry : diskRanges.entries()) {
+            slices.put(entry.getKey(), new ReferenceCountedReader(entry.getValue()));
+        }
+        return slices.build();
+    }
+
+    /**
+     * Sorts the ranges by offset and merges neighbours when the gap is at most
+     * {@code maxMergeDistance} and the spanning range is at most {@code maxReadSize}.
+     * {@code diskRanges} must not be empty.
+     */
+    private static List<DiskRange> mergeAdjacentDiskRanges(Collection<DiskRange> diskRanges, DataSize maxMergeDistance, DataSize maxReadSize)
+    {
+        // sort ranges by start offset
+        List<DiskRange> ranges = new ArrayList<>(diskRanges);
+        ranges.sort(comparingLong(DiskRange::getOffset));
+
+        long maxReadSizeBytes = maxReadSize.toBytes();
+        long maxMergeDistanceBytes = maxMergeDistance.toBytes();
+
+        // merge overlapping ranges
+        ImmutableList.Builder<DiskRange> result = ImmutableList.builder();
+        DiskRange last = ranges.get(0);
+        for (int i = 1; i < ranges.size(); i++) {
+            DiskRange current = ranges.get(i);
+            DiskRange merged = null;
+            boolean blockTooLong = false;
+            try {
+                merged = last.span(current);
+            }
+            catch (ArithmeticException e) {
+                // The spanning range cannot be represented, so the two ranges stay separate
+                blockTooLong = true;
+            }
+            if (!blockTooLong && merged.getLength() <= maxReadSizeBytes && last.getEnd() + maxMergeDistanceBytes >= current.getOffset()) {
+                last = merged;
+            }
+            else {
+                result.add(last);
+                last = current;
+            }
+        }
+        result.add(last);
+
+        return result.build();
+    }
+
+    /**
+     * Loads a range on first {@link #read()} and drops the loaded data once every reference has been freed.
+     */
+    private class ReferenceCountedReader
+            implements ChunkReader
+    {
+        private final DiskRange range;
+        private Slice data;
+        private int referenceCount = 1;
+
+        public ReferenceCountedReader(DiskRange range)
+        {
+            this.range = range;
+        }
+
+        public void addReference()
+        {
+            checkState(referenceCount > 0, "Chunk reader is already closed");
+            referenceCount++;
+        }
+
+        @Override
+        public Slice read()
+        {
+            checkState(referenceCount > 0, "Chunk reader is already closed");
+
+            if (data == null) {
+                // toIntExact rejects lengths that do not fit a byte array instead of truncating them
+                byte[] buffer = new byte[toIntExact(range.getLength())];
+                readFully(range.getOffset(), buffer, 0, buffer.length);
+                data = Slices.wrappedBuffer(buffer);
+            }
+
+            return data;
+        }
+
+        @Override
+        public void free()
+        {
+            checkState(referenceCount > 0, "Reference count is already 0");
+
+            referenceCount--;
+            if (referenceCount == 0) {
+                data = null;
+            }
+        }
+
+        @Override
+        public long getDiskOffset()
+        {
+            return range.getOffset();
+        }
+    }
+}
 diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java index d42725e5acb2..7f448bdbed2d 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java @@ -50,7 +50,7 @@ public void testReadFloatDouble() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("byte_stream_split_float_and_double.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); readAndCompare(reader, getExpectedValues()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java new file mode 100644 index 000000000000..c178d5be0261 
--- /dev/null
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.reader;
+
+import com.google.common.collect.ImmutableSet;
+import io.trino.parquet.crypto.HiddenColumnChunkMetaData;
+import io.trino.parquet.crypto.HiddenColumnException;
+import io.trino.parquet.metadata.ColumnChunkMetadata;
+import org.apache.parquet.column.Encoding;
+import org.apache.parquet.column.EncodingStats;
+import org.apache.parquet.column.statistics.Statistics;
+import org.apache.parquet.hadoop.metadata.ColumnPath;
+import org.apache.parquet.hadoop.metadata.CompressionCodecName;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Types;
+import org.testng.annotations.Test;
+
+import java.util.Collections;
+import java.util.Set;
+
+import static org.apache.parquet.column.Encoding.PLAIN;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Checks how {@link HiddenColumnChunkMetaData} differs from ordinary column chunk metadata.
+ */
+public class TestHiddenColumnChunkMetaData
+{
+    private static final String COLUMN_PATH = "a.b.c";
+    private static final String FILE_PATH = "hdfs:/foo/bar/a.parquet";
+
+    @Test
+    public void testIsHiddenColumn()
+    {
+        assertThat(HiddenColumnChunkMetaData.isHiddenColumn(hiddenColumn())).isTrue();
+    }
+
+    @Test
+    public void testIsNotHiddenColumn()
+    {
+        assertThat(HiddenColumnChunkMetaData.isHiddenColumn(visibleColumn())).isFalse();
+    }
+
+    // Asking a hidden column for its statistics is expected to fail
+    @Test(expectedExceptions = HiddenColumnException.class)
+    public void testHiddenColumnException()
+    {
+        hiddenColumn().getStatistics();
+    }
+
+    // Asking an ordinary column for its statistics must complete without throwing
+    @Test
+    public void testNoHiddenColumnException()
+    {
+        visibleColumn().getStatistics();
+    }
+
+    /** Builds hidden column chunk metadata for the fixed column and file paths. */
+    private static ColumnChunkMetadata hiddenColumn()
+    {
+        return new HiddenColumnChunkMetaData(ColumnPath.fromDotString(COLUMN_PATH), FILE_PATH);
+    }
+
+    /** Builds ordinary GZIP column chunk metadata with freshly created statistics. */
+    private static ColumnChunkMetadata visibleColumn()
+    {
+        Set<Encoding> encodings = Collections.singleton(Encoding.RLE);
+        EncodingStats encodingStats = new EncodingStats.Builder()
+                .withV2Pages()
+                .addDictEncoding(PLAIN)
+                .addDataEncodings(ImmutableSet.copyOf(encodings))
+                .build();
+        PrimitiveType type = Types.optional(BINARY).named("");
+        Statistics<?> statistics = Statistics.createStats(type);
+        return ColumnChunkMetadata.get(
+                ColumnPath.fromDotString(COLUMN_PATH),
+                type,
+                CompressionCodecName.GZIP,
+                encodingStats,
+                encodings,
+                statistics,
+                -1, -1, -1, -1, -1);
+    }
+}
 diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java index aabb734e5b0c..49e4fc2f9d80 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java @@ -112,7 +112,7 @@ public void testNanosOutsideDayRange() ParquetDataSource dataSource = new FileParquetDataSource( new 
File(Resources.getResource("int96_timestamps_nanos_outside_day_range.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); @@ -166,11 +166,12 @@ private void testVariousTimestamps(TimestampType type) slice.length(), OptionalLong.empty(), null, - false); + false, + 0); // Read and assert ColumnReaderFactory columnReaderFactory = new ColumnReaderFactory(DateTimeZone.UTC, new ParquetReaderOptions()); ColumnReader reader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false, null, null, -1, -1); reader.setPageReader(pageReader, Optional.empty()); reader.prepareNextRead(valueCount); Block block = reader.readPrimitive().getBlock(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java index 102e2b4fc01b..a94ff78cf8f2 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java @@ -25,6 +25,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetTypeUtils; +import io.trino.parquet.crypto.InternalFileDecryptor; import io.trino.parquet.metadata.ColumnChunkMetadata; import org.apache.parquet.column.ColumnDescriptor; import 
org.apache.parquet.column.EncodingStats; @@ -183,7 +184,7 @@ public void dictionaryPage(CompressionCodec compressionCodec, DataPageType dataP out.write(compressedDataPage); byte[] bytes = out.toByteArray(); - PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); DictionaryPage uncompressedDictionaryPage = pageReader.readDictionaryPage(); assertThat(uncompressedDictionaryPage.getDictionarySize()).isEqualTo(dictionaryPageHeader.getDictionary_page_header().getNum_values()); assertEncodingEquals(uncompressedDictionaryPage.getEncoding(), dictionaryPageHeader.getDictionary_page_header().getEncoding()); @@ -193,7 +194,7 @@ public void dictionaryPage(CompressionCodec compressionCodec, DataPageType dataP assertPages(compressionCodec, totalValueCount, 3, pageHeader, compressedDataPage, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); // only dictionary - pageReader = createPageReader(0, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(Arrays.copyOf(bytes, dictionaryPageSize)))); + pageReader = createPageReader(0, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(Arrays.copyOf(bytes, dictionaryPageSize))), null, -1); assertThatThrownBy(pageReader::readDictionaryPage) .isInstanceOf(IllegalStateException.class) .hasMessageStartingWith("No more data left to read"); @@ -236,7 +237,7 @@ public void dictionaryPageNotFirst() int totalValueCount = valueCount * 2; // There is a dictionary, but it's there as the second page - PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); 
assertThat(pageReader.readDictionaryPage()).isNull(); assertThat(pageReader.readPage()).isNotNull(); assertThatThrownBy(pageReader::readPage) @@ -270,7 +271,7 @@ public void unusedDictionaryPage() byte[] bytes = out.toByteArray(); // There is a dictionary, but it's there as the second page - PageReader pageReader = createPageReader(valueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(valueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); assertThat(pageReader.readDictionaryPage()).isNotNull(); assertThat(pageReader.readPage()).isNotNull(); assertThat(pageReader.readPage()).isNull(); @@ -298,7 +299,7 @@ private static void assertPages( List slices) throws IOException { - PageReader pageReader = createPageReader(valueCount, compressionCodec, hasDictionary, slices); + PageReader pageReader = createPageReader(valueCount, compressionCodec, hasDictionary, slices, null, -1); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); assertThat(dictionaryPage != null).isEqualTo(hasDictionary); @@ -383,7 +384,7 @@ private static byte[] compress(CompressionCodec compressionCodec, byte[] bytes, throw new IllegalArgumentException("unsupported compression code " + compressionCodec); } - private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices) + private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices, InternalFileDecryptor fileDecryptor, int rowGroupOrdinal) throws IOException { EncodingStats.Builder encodingStats = new EncodingStats.Builder(); @@ -409,7 +410,8 @@ private static PageReader createPageReader(int valueCount, CompressionCodec comp columnChunkMetaData, new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, INT32, ""), 0, 0), null, - Optional.empty()); + Optional.empty(), + 
Optional.ofNullable(fileDecryptor)); } private static void assertDataPageEquals(PageHeader pageHeader, byte[] dataPage, byte[] compressedDataPage, DataPage decompressedPage) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java index 2ef475a7644f..0c4f3011dbb1 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java @@ -79,7 +79,7 @@ public void testColumnReaderMemoryUsage() columnNames, generateInputPages(types, 100, 5)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThan(1); // Verify file has only non-dictionary encodings as dictionary memory usage is already tested in TestFlatColumnReader#testMemoryUsage parquetMetadata.getBlocks().forEach(block -> { @@ -132,7 +132,7 @@ public void testEmptyRowRangesWithColumnIndex() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("lineitem_sorted_by_shipdate/data.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(2); // The predicate and the file are prepared so that page indexes will result in non-overlapping row ranges and eliminate the entire first row group // while the second row group still has to be read @@ -193,7 +193,7 @@ private void testReadingOldParquetFiles(File file, List columnNames, Typ file, new ParquetReaderOptions()); ConnectorSession 
session = TestingConnectorSession.builder().build(); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), ImmutableList.of(columnType), columnNames)) { Page page = reader.nextPage(); Iterator> expected = expectedValues.iterator(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java index 390608f445a9..99ae226bca08 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java @@ -60,7 +60,7 @@ private void testTimeMillsInt32(TimeType timeType) ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("time_millis_int32.snappy.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java index a3efb46b6d71..8222899ab90b 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java @@ -137,8 +137,9 @@ private static PageReader getSimplePageReaderMock(ParquetEncoding encoding) OptionalLong.empty(), 
encoding, encoding, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, null, null, -1, -1); } private static PageReader getNullOnlyPageReaderMock() @@ -154,7 +155,8 @@ private static PageReader getNullOnlyPageReaderMock() OptionalLong.empty(), RLE, RLE, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, null, null, -1, -1); } } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java index 846080c3297a..717474419d11 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java @@ -127,7 +127,7 @@ public void testWrittenPageSize() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -141,6 +141,7 @@ public void testWrittenPageSize() chunkMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -176,7 +177,7 @@ public void testWrittenPageValueCount() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = 
MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -194,6 +195,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -213,6 +215,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columnb"}, new PrimitiveType(REQUIRED, INT64, "columnb"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -256,8 +259,7 @@ public void testLargeStringTruncation() columnNames, ImmutableList.of(new Page(2, blockA, blockB))), new ParquetReaderOptions()); - - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); BlockMetadata blockMetaData = getOnlyElement(parquetMetadata.getBlocks()); ColumnChunkMetadata chunkMetaData = blockMetaData.columns().get(0); @@ -290,7 +292,7 @@ public void testColumnReordering() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(10); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { // Verify that the columns are stored in the same order as the metadata @@ -347,7 +349,7 @@ public void testDictionaryPageOffset() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - 
ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(1); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { ColumnChunkMetadata chunkMetaData = getOnlyElement(blockMetaData.columns()); @@ -393,7 +395,7 @@ public void testWriteBloomFilters(Type type, List> data) generateInputPages(types, 100, data)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); // Check that bloom filters are right after each other int bloomFilterSize = Integer.highestOneBit(BlockSplitBloomFilter.optimalNumOfBits(BLOOM_FILTER_EXPECTED_ENTRIES, DEFAULT_BLOOM_FILTER_FPP) / 8) << 1; for (BlockMetadata block : parquetMetadata.getBlocks()) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index 5fe764a72756..eb0a41cd8108 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -362,7 +362,7 @@ private Slice writeMergeResult(Slice path, FileDeletion deletion) TrinoInputFile inputFile = fileSystem.newInputFile(Location.of(path.toStringUtf8())); try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, parquetReaderOptions, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); long rowCount = 
parquetMetadata.getBlocks().stream().map(BlockMetadata::rowCount).mapToLong(Long::longValue).sum(); RoaringBitmapArray rowsRetained = new RoaringBitmapArray(); rowsRetained.addRange(0, rowCount - 1); @@ -637,7 +637,8 @@ private ReaderPageSource createParquetPageSource(Location path) new ParquetReaderOptions().withBloomFilter(false), Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + null); } @Override diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index f08ecc84f839..c552b1944e2c 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -254,7 +254,8 @@ public ConnectorPageSource createPageSource( options, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(split.getFileSize())); + OptionalLong.of(split.getFileSize()), + null); Optional projectionsAdapter = pageSource.getReaderColumns().map(readerColumns -> new ReaderProjectionsAdapter( @@ -306,7 +307,7 @@ private PositionDeleteFilter readDeletes( public Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java index 8f686205e239..5330c6edd100 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java @@ -184,7 +184,7 @@ public DataFileInfo getDataFileInfo() { Location path = rootTableLocation.appendPath(relativeFilePath); FileMetaData fileMetaData = fileWriter.getFileMetadata(); - ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString())); + ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString()), Optional.empty(), false); return new DataFileInfo( relativeFilePath, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java index 7f5d4b8a88c6..3c57de2ef2f3 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java @@ -205,7 +205,8 @@ private static DeltaLakePageSource createDeltaLakePageSource( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.empty()); + OptionalLong.of(split.fileSize()), + null); verify(pageSource.getReaderColumns().isEmpty(), "Unexpected reader columns: %s", pageSource.getReaderColumns().orElse(null)); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 
04673aeab8ea..985cc433aaea 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -231,7 +231,8 @@ public CheckpointEntryIterator( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + Optional.empty()); this.pageSource = (ParquetPageSource) pageSource.get(); try { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java index 70cdce9c5e4f..7f9050feb512 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java @@ -329,7 +329,7 @@ private void testOptimizeWithColumnMappingMode(String columnMappingMode) TrinoInputFile inputFile = new LocalInputFile(tableLocation.resolve(addFileEntry.getPath()).toFile()); ParquetMetadata parquetMetadata = MetadataReader.readFooter( new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()), - Optional.empty()); + Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); PrimitiveType physicalType = getOnlyElement(fileMetaData.getSchema().getColumns().iterator()).getPrimitiveType(); assertThat(physicalType.getName()).isEqualTo(physicalName); diff --git a/plugin/trino-geospatial/pom.xml b/plugin/trino-geospatial/pom.xml index 6d975cb4232a..1636b6f25102 100644 --- a/plugin/trino-geospatial/pom.xml +++ b/plugin/trino-geospatial/pom.xml @@ -230,4 +230,20 @@ test + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + mozilla/.* + about.html + mime.types + iceberg-build.properties + + + + + diff --git 
a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index 07e7f860108a..9a5dcf292204 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -379,6 +379,21 @@ runtime + + + org.codehaus.jackson + jackson-core-asl + 1.9.13 + runtime + + + + org.codehaus.jackson + jackson-mapper-asl + 1.9.13 + runtime + + org.jetbrains annotations @@ -661,6 +676,17 @@ + + org.basepom.maven + duplicate-finder-maven-plugin + + + about.html + iceberg-build.properties + mozilla/public-suffix-list.txt + + + diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java index 325f2292da83..61bf0717f3e4 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java @@ -195,9 +195,7 @@ public static Optional createHivePageSource( Optional bucketAdaptation = createBucketAdaptation(bucketConversion, tableBucketNumber, regularAndInterimColumnMappings); Optional bucketValidator = createBucketValidator(path, bucketValidation, tableBucketNumber, regularAndInterimColumnMappings); - CoercionContext coercionContext = new CoercionContext(getTimestampPrecision(session), extractHiveStorageFormat(getDeserializerClassName(schema))); - for (HivePageSourceFactory pageSourceFactory : pageSourceFactories) { List desiredColumns = toColumnHandles(regularAndInterimColumnMappings, typeManager, coercionContext); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index dcd88f2523d6..041bda8bd44b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -24,12 +24,14 @@ import 
io.trino.memory.context.AggregatedMemoryContext; import io.trino.metastore.HiveType; import io.trino.parquet.Column; +import io.trino.parquet.EncryptionUtils; import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.crypto.InternalFileDecryptor; import io.trino.parquet.metadata.FileMetadata; import io.trino.parquet.metadata.ParquetMetadata; import io.trino.parquet.predicate.TupleDomainParquetPredicate; @@ -178,7 +180,7 @@ public Optional createPageSource( TrinoFileSystem fileSystem = fileSystemFactory.create(session); TrinoInputFile inputFile = fileSystem.newInputFile(path, estimatedFileSize, Instant.ofEpochMilli(fileModifiedTime)); - + final Optional internalFileDecryptor = EncryptionUtils.createDecryptor(options, path, fileSystem); return Optional.of(createPageSource( inputFile, start, @@ -197,7 +199,8 @@ public Optional createPageSource( .withVectorizedDecodingEnabled(isParquetVectorizedDecodingEnabled(session)), Optional.empty(), domainCompactionThreshold, - OptionalLong.of(estimatedFileSize))); + OptionalLong.of(estimatedFileSize), + internalFileDecryptor)); } /** @@ -215,7 +218,8 @@ public static ReaderPageSource createPageSource( ParquetReaderOptions options, Optional parquetWriteValidation, int domainCompactionThreshold, - OptionalLong estimatedFileSize) + OptionalLong estimatedFileSize, + Optional internalFileDecryptor) { MessageType fileSchema; MessageType requestedSchema; @@ -224,8 +228,7 @@ public static ReaderPageSource createPageSource( try { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); dataSource = createDataSource(inputFile, estimatedFileSize, options, memoryContext, stats); - - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, parquetWriteValidation); + 
ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, parquetWriteValidation, internalFileDecryptor); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); fileSchema = fileMetaData.getSchema(); @@ -286,7 +289,8 @@ public static ReaderPageSource createPageSource( // We avoid using disjuncts of parquetPredicate for page pruning in ParquetReader as currently column indexes // are not present in the Parquet files which are read with disjunct predicates. parquetPredicates.size() == 1 ? Optional.of(parquetPredicates.get(0)) : Optional.empty(), - parquetWriteValidation); + parquetWriteValidation, + internalFileDecryptor); ConnectorPageSource parquetPageSource = createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); return new ReaderPageSource(parquetPageSource, readerProjections); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java index 45b0ad2fade6..5fdc7915d19f 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java @@ -160,6 +160,103 @@ public boolean isVectorizedDecodingEnabled() return options.isVectorizedDecodingEnabled(); } + @Config("parquet.crypto-factory-class") + @ConfigDescription("Crypto factory class to encrypt or decrypt parquet files") + public ParquetReaderConfig setCryptoFactoryClass(String cryptoFactoryClass) + { + options = options.withEncryptionOption(options.encryptionOptions().withCryptoFactoryClass(cryptoFactoryClass)); + return this; + } + + public String getCryptoFactoryClass() + { + return options.getCryptoFactoryClass(); + } + + @Config("parquet.encryption-kms-client-class") + @ConfigDescription("Class implementing the KmsClient interface. 
KMS stands for “key management service") + public ParquetReaderConfig setEncryptionKmsClientClass(String encryptionKmsClientClass) + { + options = options.withEncryptionOption( + options.encryptionOptions().withEncryptionKmsClientClass(encryptionKmsClientClass)); + return this; + } + + public String getEncryptionKmsClientClass() + { + return options.getEncryptionKmsClientClass(); + } + + @Config("parquet.encryption-kms-instance-id") + @ConfigDescription("") + public ParquetReaderConfig setEncryptionKmsInstanceId(String encryptionKmsInstanceId) + { + options = options.withEncryptionOption( + options.encryptionOptions().withEncryptionKmsInstanceId(encryptionKmsInstanceId)); + return this; + } + + public String getEncryptionKmsInstanceId() + { + return options.getEncryptionKmsInstanceId(); + } + + @Config("parquet.encryption-kms-instance-url") + @ConfigDescription("") + public ParquetReaderConfig setEncryptionKmsInstanceUrl(String encryptionKmsInstanceUrl) + { + options = options.withEncryptionOption( + options.encryptionOptions().withEncryptionKmsInstanceUrl(encryptionKmsInstanceUrl)); + return this; + } + + public String getEncryptionKmsInstanceUrl() + { + return options.getEncryptionKmsInstanceUrl(); + } + + @Config("parquet.encryption-key-access-token") + @ConfigDescription("") + public ParquetReaderConfig setEncryptionKeyAccessToken(String encryptionKeyAccessToken) + { + options = options.withEncryptionOption( + options.encryptionOptions().withEncryptionKeyAccessToken(encryptionKeyAccessToken)); + return this; + } + + public String getEncryptionKeyAccessToken() + { + return options.getEncryptionKeyAccessToken(); + } + + @Config("parquet.encryption-cache-lifetime-seconds") + @ConfigDescription("") + public ParquetReaderConfig setEncryptionCacheLifetimeSeconds(Long encryptionCacheLifetimeSeconds) + { + options = options.withEncryptionOption( + options.encryptionOptions().withEncryptionCacheLifetimeSeconds(encryptionCacheLifetimeSeconds)); + return this; + } + + 
public Long getEncryptionCacheLifetimeSeconds() + { + return options.getEncryptionCacheLifetimeSeconds(); + } + + public String getEncryptionMasterKeyFile() + { + return options.getEncryptionKeyFile(); + } + + @Config("parquet.encryption-master-key-file") + @ConfigDescription("the path to master key file") + public ParquetReaderConfig setEncryptionMasterKeyFile(String keyFile) + { + options = options.withEncryptionOption( + options.encryptionOptions().withEncryptionKeyFile(keyFile)); + return this; + } + public ParquetReaderOptions toParquetReaderOptions() { return options; diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java index 670d3ac0e259..398ff26937ea 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java @@ -20,6 +20,9 @@ import io.airlift.units.Duration; import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Map; import java.util.TimeZone; import java.util.concurrent.TimeUnit; @@ -123,7 +126,11 @@ public void testDefaults() @Test public void testExplicitPropertyMappings() + throws IOException { + Path resource1 = Files.createTempFile(null, null); + Path resource2 = Files.createTempFile(null, null); + Map properties = ImmutableMap.builder() .put("hive.single-statement-writes", "true") .put("hive.max-split-size", "256MB") diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java index 10e1e3378366..cb8f7e592230 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java @@ -308,7 +308,7 @@ private static 
BloomFilterStore generateBloomFilterStore(ParquetTester.TempFile TrinoInputFile inputFile = new LocalInputFile(tempFile.getFile()); TrinoParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ColumnChunkMetadata columnChunkMetaData = getOnlyElement(getOnlyElement(parquetMetadata.getBlocks()).columns()); return new BloomFilterStore(dataSource, getOnlyElement(parquetMetadata.getBlocks()), Set.of(columnChunkMetaData.getPath())); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java index 6d980a2483ad..e8e060ace4ee 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java @@ -39,7 +39,14 @@ public void testDefaults() .setUseColumnIndex(true) .setUseBloomFilter(true) .setSmallFileThreshold(DataSize.of(3, MEGABYTE)) - .setVectorizedDecodingEnabled(true)); + .setVectorizedDecodingEnabled(true) + .setCryptoFactoryClass(null) + .setEncryptionKmsClientClass(null) + .setEncryptionKmsInstanceId(null) + .setEncryptionKmsInstanceUrl(null) + .setEncryptionCacheLifetimeSeconds(600L) + .setEncryptionKeyAccessToken("DEFAULT") + .setEncryptionMasterKeyFile(null)); } @Test @@ -55,6 +62,13 @@ public void testExplicitPropertyMappings() .put("parquet.use-bloom-filter", "false") .put("parquet.small-file-threshold", "1kB") .put("parquet.experimental.vectorized-decoding.enabled", "false") + .put("parquet.crypto-factory-class", "test") + .put("parquet.encryption-cache-lifetime-seconds", "100") + .put("parquet.encryption-key-access-token", 
"testToken") + .put("parquet.encryption-kms-client-class", "testKmsClient") + .put("parquet.encryption-kms-instance-id", "testInstanceId") + .put("parquet.encryption-kms-instance-url", "testKmsUrl") + .put("parquet.encryption-master-key-file", "testKeyFile") .buildOrThrow(); ParquetReaderConfig expected = new ParquetReaderConfig() @@ -66,7 +80,14 @@ public void testExplicitPropertyMappings() .setUseColumnIndex(false) .setUseBloomFilter(false) .setSmallFileThreshold(DataSize.of(1, KILOBYTE)) - .setVectorizedDecodingEnabled(false); + .setVectorizedDecodingEnabled(false) + .setCryptoFactoryClass("test") + .setEncryptionKmsClientClass("testKmsClient") + .setEncryptionKmsInstanceId("testInstanceId") + .setEncryptionKmsInstanceUrl("testKmsUrl") + .setEncryptionCacheLifetimeSeconds(100L) + .setEncryptionKeyAccessToken("testToken") + .setEncryptionMasterKeyFile("testKeyFile"); assertFullMapping(properties, expected); } diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java index 15129f4ee17d..857ae4b672ef 100644 --- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java +++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java @@ -198,7 +198,7 @@ private static ConnectorPageSource createPageSource( try { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); dataSource = createDataSource(inputFile, OptionalLong.of(hudiSplit.getFileSize()), options, memoryContext, dataSourceStats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); @@ -244,6 +244,7 @@ private static ConnectorPageSource createPageSource( options, 
exception -> handleException(dataSourceId, exception), Optional.of(parquetPredicate), + Optional.empty(), Optional.empty()); return createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index d937f5c57133..9984e4092983 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -912,7 +912,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( ParquetDataSource dataSource = null; try { dataSource = createDataSource(inputFile, OptionalLong.of(fileSize), options, memoryContext, fileFormatDataSourceStats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); if (nameMapping.isPresent() && !ParquetSchemaUtil.hasIds(fileSchema)) { @@ -1023,6 +1023,7 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { options, exception -> handleException(dataSourceId, exception), Optional.empty(), + Optional.empty(), Optional.empty()); return new ReaderPageSourceWithRowPositions( new ReaderPageSource( diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java index 7f0716b66188..b6028856dc20 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java @@ -83,7 +83,7 @@ public FileMetrics getFileMetrics() { ParquetMetadata parquetMetadata; try { - parquetMetadata = createParquetMetadata(parquetFileWriter.getFileMetadata(), new ParquetDataSourceId(location.toString())); + parquetMetadata = createParquetMetadata(parquetFileWriter.getFileMetadata(), new ParquetDataSourceId(location.toString()), Optional.empty(), false); } catch (IOException e) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Error creating metadata for Parquet file %s", location), e); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java index caaead5fc3f4..eadef2c1096b 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java @@ -400,7 +400,7 @@ private static Metrics loadMetrics(TrinoInputFile file, HiveStorageFormat storag private static Metrics parquetMetrics(TrinoInputFile file, MetricsConfig metricsConfig, NameMapping nameMapping) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(file, new ParquetReaderOptions(), new FileFormatDataSourceStats())) { - ParquetMetadata metadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata metadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); return ParquetUtil.footerMetrics(metadata, Stream.empty(), metricsConfig, nameMapping); } catch (IOException e) { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java index 24c18d4162ec..ce5b2e6fd9c6 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java +++ 
b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java @@ -131,7 +131,7 @@ public static boolean checkParquetFileSorting(TrinoInputFile inputFile, String s try { parquetMetadata = MetadataReader.readFooter( new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()), - Optional.empty()); + Optional.empty(), Optional.empty()); } catch (IOException e) { throw new UncheckedIOException(e);
+ * Key material is kept in a flat json object, with the following fields: + * 1. "keyMaterialType" - a String, with the type of key material. In the current version, only one value is allowed - "PKMT1" (stands + * for "parquet key management tools, version 1"). For external key material storage, this field is written in both "key metadata" and + * "key material" jsons. For internal key material storage, this field is written only once in the common json. + * 2. "isFooterKey" - a boolean. If true, means that the material belongs to a file footer key, and keeps additional information (such as + * KMS instance ID and URL). If false, means that the material belongs to a column key. + * 3. "kmsInstanceID" - a String, with the KMS Instance ID. Written only in footer key material. + * 4. "kmsInstanceURL" - a String, with the KMS Instance URL. Written only in footer key material. + * 5. "masterKeyID" - a String, with the ID of the master key used to generate the material. + * 6. "wrappedDEK" - a String, with the wrapped DEK (base64 encoding). + * 7. "doubleWrapping" - a boolean. If true, means that the material was generated in double wrapping mode. + * If false - in single wrapping mode. + * 8. "keyEncryptionKeyID" - a String, with the ID of the KEK used to generate the material. Written only in double wrapping mode. + * 9. "wrappedKEK" - a String, with the wrapped KEK (base64 encoding). Written only in double wrapping mode. 
+ */ +public class KeyMaterial +{ + static final String KEY_MATERIAL_TYPE_FIELD = "keyMaterialType"; + static final String KEY_MATERIAL_TYPE1 = "PKMT1"; + + static final String FOOTER_KEY_ID_IN_FILE = "footerKey"; + static final String COLUMN_KEY_ID_IN_FILE_PREFIX = "columnKey"; + + private static final String IS_FOOTER_KEY_FIELD = "isFooterKey"; + private static final String DOUBLE_WRAPPING_FIELD = "doubleWrapping"; + private static final String KMS_INSTANCE_ID_FIELD = "kmsInstanceID"; + private static final String KMS_INSTANCE_URL_FIELD = "kmsInstanceURL"; + private static final String MASTER_KEY_ID_FIELD = "masterKeyID"; + private static final String WRAPPED_DEK_FIELD = "wrappedDEK"; + private static final String KEK_ID_FIELD = "keyEncryptionKeyID"; + private static final String WRAPPED_KEK_FIELD = "wrappedKEK"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private final boolean isFooterKey; + private final String kmsInstanceID; + private final String kmsInstanceURL; + private final String masterKeyID; + private final boolean isDoubleWrapped; + private final String kekID; + private final String encodedWrappedKEK; + private final String encodedWrappedDEK; + + private KeyMaterial( + boolean isFooterKey, + String kmsInstanceID, + String kmsInstanceURL, + String masterKeyID, + boolean isDoubleWrapped, + String kekID, + String encodedWrappedKEK, + String encodedWrappedDEK) + { + this.isFooterKey = isFooterKey; + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + this.masterKeyID = masterKeyID; + this.isDoubleWrapped = isDoubleWrapped; + this.kekID = kekID; + this.encodedWrappedKEK = encodedWrappedKEK; + this.encodedWrappedDEK = encodedWrappedDEK; + } + + // parses external key material + static KeyMaterial parse(String keyMaterialString) + { + Map keyMaterialJson = null; + try { + keyMaterialJson = OBJECT_MAPPER.readValue( + new StringReader(keyMaterialString), new TypeReference>() {}); + } + catch 
(IOException e) { + throw new ParquetCryptoRuntimeException("Failed to parse key metadata " + keyMaterialString, e); + } + // 1. External key material - extract "key material type", and make sure it is supported + String keyMaterialType = (String) keyMaterialJson.get(KEY_MATERIAL_TYPE_FIELD); + if (!KEY_MATERIAL_TYPE1.equals(keyMaterialType)) { + throw new ParquetCryptoRuntimeException( + "Wrong key material type: " + keyMaterialType + " vs " + KEY_MATERIAL_TYPE1); + } + // Parse other fields (common to internal and external key material) + return parse(keyMaterialJson); + } + + // parses fields common to internal and external key material + static KeyMaterial parse(Map keyMaterialJson) + { + // 2. Check if "key material" belongs to file footer key + Boolean isFooterKey = (Boolean) keyMaterialJson.get(IS_FOOTER_KEY_FIELD); + String kmsInstanceID = null; + String kmsInstanceURL = null; + if (isFooterKey) { + // 3. For footer key, extract KMS Instance ID + kmsInstanceID = (String) keyMaterialJson.get(KMS_INSTANCE_ID_FIELD); + // 4. For footer key, extract KMS Instance URL + kmsInstanceURL = (String) keyMaterialJson.get(KMS_INSTANCE_URL_FIELD); + } + // 5. Extract master key ID + String masterKeyID = (String) keyMaterialJson.get(MASTER_KEY_ID_FIELD); + // 6. Extract wrapped DEK + String encodedWrappedDEK = (String) keyMaterialJson.get(WRAPPED_DEK_FIELD); + String kekID = null; + String encodedWrappedKEK = null; + // 7. Check if "key material" was generated in double wrapping mode + Boolean isDoubleWrapped = (Boolean) keyMaterialJson.get(DOUBLE_WRAPPING_FIELD); + if (isDoubleWrapped) { + // 8. In double wrapping mode, extract KEK ID + kekID = (String) keyMaterialJson.get(KEK_ID_FIELD); + // 9. 
In double wrapping mode, extract wrapped KEK + encodedWrappedKEK = (String) keyMaterialJson.get(WRAPPED_KEK_FIELD); + } + + return new KeyMaterial( + isFooterKey, + kmsInstanceID, + kmsInstanceURL, + masterKeyID, + isDoubleWrapped, + kekID, + encodedWrappedKEK, + encodedWrappedDEK); + } + + static String createSerialized( + boolean isFooterKey, + String kmsInstanceID, + String kmsInstanceURL, + String masterKeyID, + boolean isDoubleWrapped, + String kekID, + String encodedWrappedKEK, + String encodedWrappedDEK, + boolean isInternalStorage) + { + Map keyMaterialMap = new HashMap(10); + // 1. Write "key material type" + keyMaterialMap.put(KEY_MATERIAL_TYPE_FIELD, KEY_MATERIAL_TYPE1); + if (isInternalStorage) { + // for internal storage, key material and key metadata are the same. + // adding the "internalStorage" field that belongs to KeyMetadata. + keyMaterialMap.put(KeyMetadata.KEY_MATERIAL_INTERNAL_STORAGE_FIELD, Boolean.TRUE); + } + // 2. Write isFooterKey + keyMaterialMap.put(IS_FOOTER_KEY_FIELD, isFooterKey); + if (isFooterKey) { + // 3. For footer key, write KMS Instance ID + keyMaterialMap.put(KMS_INSTANCE_ID_FIELD, kmsInstanceID); + // 4. For footer key, write KMS Instance URL + keyMaterialMap.put(KMS_INSTANCE_URL_FIELD, kmsInstanceURL); + } + // 5. Write master key ID + keyMaterialMap.put(MASTER_KEY_ID_FIELD, masterKeyID); + // 6. Write wrapped DEK + keyMaterialMap.put(WRAPPED_DEK_FIELD, encodedWrappedDEK); + // 7. Write isDoubleWrapped + keyMaterialMap.put(DOUBLE_WRAPPING_FIELD, isDoubleWrapped); + if (isDoubleWrapped) { + // 8. In double wrapping mode, write KEK ID + keyMaterialMap.put(KEK_ID_FIELD, kekID); + // 9. 
In double wrapping mode, write wrapped KEK + keyMaterialMap.put(WRAPPED_KEK_FIELD, encodedWrappedKEK); + } + + try { + return OBJECT_MAPPER.writeValueAsString(keyMaterialMap); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to serialize key material", e); + } + } + + boolean isFooterKey() + { + return isFooterKey; + } + + boolean isDoubleWrapped() + { + return isDoubleWrapped; + } + + String getMasterKeyID() + { + return masterKeyID; + } + + String getWrappedDEK() + { + return encodedWrappedDEK; + } + + String getKekID() + { + return kekID; + } + + String getWrappedKEK() + { + return encodedWrappedKEK; + } + + String getKmsInstanceID() + { + return kmsInstanceID; + } + + String getKmsInstanceURL() + { + return kmsInstanceURL; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMetadata.java new file mode 100644 index 000000000000..54d4d7227dc9 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/KeyMetadata.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.IOException; +import java.io.StringReader; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +/** + * Parquet encryption specification defines "key metadata" as an arbitrary byte array, generated by file writers for each encryption key, + * and passed to the low level API for storage in the file footer . The "key metadata" field is made available to file readers to enable + * recovery of the key. This simple interface can be utilized for implementation of any key management scheme. + * + * The keytools package (PARQUET-1373) implements one approach, of many possible, to key management and to generation of the "key metadata" + * fields. This approach, based on the "envelope encryption" pattern, allows to work with KMS servers. It keeps the actual material, + * required to recover a key, in a "key material" object (see the KeyMaterial class for details). + * + * KeyMetadata class writes (and reads) the "key metadata" field as a flat json object, with the following fields: + * 1. "keyMaterialType" - a String, with the type of key material. In the current version, only one value is allowed - "PKMT1" (stands + * for "parquet key management tools, version 1") + * 2. "internalStorage" - a boolean. If true, means that "key material" is kept inside the "key metadata" field. If false, "key material" + * is kept externally (outside Parquet files) - in this case, "key metadata" keeps a reference to the external "key material". + * 3. "keyReference" - a String, with the reference to the external "key material". Written only if internalStorage is false. 
+ * + * If internalStorage is true, "key material" is a part of "key metadata", and the json keeps additional fields, described in the + * KeyMaterial class. + */ +public class KeyMetadata +{ + static final String KEY_MATERIAL_INTERNAL_STORAGE_FIELD = "internalStorage"; + private static final String KEY_REFERENCE_FIELD = "keyReference"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private final boolean isInternalStorage; + private final String keyReference; + private final KeyMaterial keyMaterial; + + private KeyMetadata(boolean isInternalStorage, String keyReference, KeyMaterial keyMaterial) + { + this.isInternalStorage = isInternalStorage; + this.keyReference = keyReference; + this.keyMaterial = keyMaterial; + } + + static KeyMetadata parse(byte[] keyMetadataBytes) + { + String keyMetaDataString = new String(keyMetadataBytes, StandardCharsets.UTF_8); + Map keyMetadataJson = null; + try { + keyMetadataJson = OBJECT_MAPPER.readValue( + new StringReader(keyMetaDataString), new TypeReference>() {}); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to parse key metadata " + keyMetaDataString, e); + } + + // 1. Extract "key material type", and make sure it is supported + String keyMaterialType = (String) keyMetadataJson.get(KeyMaterial.KEY_MATERIAL_TYPE_FIELD); + if (!KeyMaterial.KEY_MATERIAL_TYPE1.equals(keyMaterialType)) { + throw new ParquetCryptoRuntimeException( + "Wrong key material type: " + keyMaterialType + " vs " + KeyMaterial.KEY_MATERIAL_TYPE1); + } + + // 2. 
Check if "key material" is stored internally in Parquet file key metadata, or is stored externally + Boolean isInternalStorage = (Boolean) keyMetadataJson.get(KEY_MATERIAL_INTERNAL_STORAGE_FIELD); + String keyReference; + KeyMaterial keyMaterial; + + if (isInternalStorage) { + // 3.1 "key material" is stored internally, inside "key metadata" - parse it + keyMaterial = KeyMaterial.parse(keyMetadataJson); + keyReference = null; + } + else { + // 3.2 "key material" is stored externally. "key metadata" keeps a reference to it + keyReference = (String) keyMetadataJson.get(KEY_REFERENCE_FIELD); + keyMaterial = null; + } + + return new KeyMetadata(isInternalStorage, keyReference, keyMaterial); + } + + // For external material only. For internal material, create serialized KeyMaterial directly + static String createSerializedForExternalMaterial(String keyReference) + { + Map keyMetadataMap = new HashMap(3); + // 1. Write "key material type" + keyMetadataMap.put(KeyMaterial.KEY_MATERIAL_TYPE_FIELD, KeyMaterial.KEY_MATERIAL_TYPE1); + // 2. Write internal storage as false + keyMetadataMap.put(KEY_MATERIAL_INTERNAL_STORAGE_FIELD, Boolean.FALSE); + // 3. 
For externally stored "key material", "key metadata" keeps only a reference to it + keyMetadataMap.put(KEY_REFERENCE_FIELD, keyReference); + + try { + return OBJECT_MAPPER.writeValueAsString(keyMetadataMap); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to serialize key metadata", e); + } + } + + boolean keyMaterialStoredInternally() + { + return isInternalStorage; + } + + KeyMaterial getKeyMaterial() + { + return keyMaterial; + } + + String getKeyReference() + { + return keyReference; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java new file mode 100644 index 000000000000..0c5bee3b6da8 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java @@ -0,0 +1,164 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.parquet.crypto.keytools; + +import com.google.common.base.Strings; +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KeyWithMasterID; + +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +import static io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KEK_READ_CACHE_PER_TOKEN; +import static io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KMS_CLIENT_CACHE_PER_TOKEN; + +public class TrinoFileKeyUnwrapper + implements DecryptionKeyRetriever +{ + private static final Logger LOG = Logger.get(TrinoFileKeyUnwrapper.class); + + //A map of KEK_ID -> KEK bytes, for the current token + private final ConcurrentMap kekPerKekID; + private final Location parquetFilePath; + // TODO(wyu): shall we get it from Location or File + private final TrinoFileSystem trinoFileSystem; + private final String accessToken; + private final long cacheEntryLifetime; + private final ParquetReaderOptions parquetReaderOptions; + private TrinoKeyToolkit.TrinoKmsClientAndDetails kmsClientAndDetails; + private TrinoHadoopFSKeyMaterialStore keyMaterialStore; + private boolean checkedKeyMaterialInternalStorage; + + TrinoFileKeyUnwrapper(ParquetReaderOptions conf, Location filePath, TrinoFileSystem trinoFileSystem) + { + this.trinoFileSystem = trinoFileSystem; + this.parquetReaderOptions = conf; + this.parquetFilePath = filePath; + this.cacheEntryLifetime = 1000L * conf.getEncryptionCacheLifetimeSeconds(); + this.accessToken = conf.getEncryptionKeyAccessToken(); + this.kmsClientAndDetails = null; + this.keyMaterialStore = null; + this.checkedKeyMaterialInternalStorage = false; + + // Check cache upon each file reading (clean once in cacheEntryLifetime) + 
KMS_CLIENT_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + KEK_READ_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + kekPerKekID = KEK_READ_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + if (LOG.isDebugEnabled()) { + LOG.debug("Creating file key unwrapper. KeyMaterialStore: {}; token snippet: {}", + keyMaterialStore, TrinoKeyToolkit.formatTokenForLog(accessToken)); + } + } + + @Override + public byte[] getKey(byte[] keyMetadataBytes) + { + KeyMetadata keyMetadata = KeyMetadata.parse(keyMetadataBytes); + + if (!checkedKeyMaterialInternalStorage) { + if (!keyMetadata.keyMaterialStoredInternally()) { + keyMaterialStore = new TrinoHadoopFSKeyMaterialStore(trinoFileSystem, parquetFilePath, false); + } + checkedKeyMaterialInternalStorage = true; + } + + KeyMaterial keyMaterial; + if (keyMetadata.keyMaterialStoredInternally()) { + // Internal key material storage: key material is inside key metadata + keyMaterial = keyMetadata.getKeyMaterial(); + } + else { + // External key material storage: key metadata contains a reference to a key in the material store + String keyIDinFile = keyMetadata.getKeyReference(); + String keyMaterialString = keyMaterialStore.getKeyMaterial(keyIDinFile); + if (null == keyMaterialString) { + throw new ParquetCryptoRuntimeException("Null key material for keyIDinFile: " + keyIDinFile); + } + keyMaterial = KeyMaterial.parse(keyMaterialString); + } + + return getDEKandMasterID(keyMaterial).getDataKey(); + } + + KeyWithMasterID getDEKandMasterID(KeyMaterial keyMaterial) + { + if (null == kmsClientAndDetails) { + kmsClientAndDetails = getKmsClientFromConfigOrKeyMaterial(keyMaterial); + } + + boolean doubleWrapping = keyMaterial.isDoubleWrapped(); + String masterKeyID = keyMaterial.getMasterKeyID(); + String encodedWrappedDEK = keyMaterial.getWrappedDEK(); + + byte[] dataKey; + TrinoKmsClient kmsClient = kmsClientAndDetails.getKmsClient(); + if (!doubleWrapping) { + dataKey = 
kmsClient.unwrapKey(encodedWrappedDEK, masterKeyID); + } + else { + // Get KEK + String encodedKekID = keyMaterial.getKekID(); + String encodedWrappedKEK = keyMaterial.getWrappedKEK(); + + byte[] kekBytes = kekPerKekID.computeIfAbsent(encodedKekID, + (k) -> kmsClient.unwrapKey(encodedWrappedKEK, masterKeyID)); + + if (null == kekBytes) { + throw new ParquetCryptoRuntimeException("Null KEK, after unwrapping in KMS with master key " + masterKeyID); + } + + // Decrypt the data key + byte[] aad = Base64.getDecoder().decode(encodedKekID); + dataKey = TrinoKeyToolkit.decryptKeyLocally(encodedWrappedDEK, kekBytes, aad); + } + + return new KeyWithMasterID(dataKey, masterKeyID); + } + + TrinoKeyToolkit.TrinoKmsClientAndDetails getKmsClientFromConfigOrKeyMaterial(KeyMaterial keyMaterial) + { + String kmsInstanceID = this.parquetReaderOptions.getEncryptionKmsInstanceId(); + if (Strings.isNullOrEmpty(kmsInstanceID)) { + kmsInstanceID = keyMaterial.getKmsInstanceID(); + if (null == kmsInstanceID) { + throw new ParquetCryptoRuntimeException("KMS instance ID is missing both in properties and file key material"); + } + } + + String kmsInstanceURL = this.parquetReaderOptions.getEncryptionKmsInstanceUrl(); + if (Strings.isNullOrEmpty(kmsInstanceURL)) { + kmsInstanceURL = keyMaterial.getKmsInstanceURL(); + if (null == kmsInstanceURL) { + throw new ParquetCryptoRuntimeException("KMS instance URL is missing both in properties and file key material"); + } + } + + TrinoKmsClient kmsClient = TrinoKeyToolkit.getKmsClient(kmsInstanceID, kmsInstanceURL, this.parquetReaderOptions, accessToken, cacheEntryLifetime); + if (null == kmsClient) { + throw new ParquetCryptoRuntimeException("KMSClient was not successfully created for reading encrypted data."); + } + + if (LOG.isDebugEnabled()) { + LOG.debug("File unwrapper - KmsClient: {}; InstanceId: {}; InstanceURL: {}", kmsClient, kmsInstanceID, kmsInstanceURL); + } + return new TrinoKeyToolkit.TrinoKmsClientAndDetails(kmsClient, kmsInstanceID, 
kmsInstanceURL); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java new file mode 100644 index 000000000000..4c178c0bd8fe --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoInputStream; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Map; + +public class TrinoHadoopFSKeyMaterialStore +{ + public static final String KEY_MATERIAL_FILE_PREFIX = "_KEY_MATERIAL_FOR_"; + public static final String TEMP_FILE_PREFIX = "_TMP"; + public static final String KEY_MATERIAL_FILE_SUFFFIX = ".json"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private TrinoFileSystem trinoFileSystem; + private Map keyMaterialMap; + private Location keyMaterialFile; + + 
TrinoHadoopFSKeyMaterialStore(TrinoFileSystem trinoFileSystem, Location parquetFilePath, boolean tempStore) + { + this.trinoFileSystem = trinoFileSystem; + String fullPrefix = (tempStore ? TEMP_FILE_PREFIX : ""); + fullPrefix += KEY_MATERIAL_FILE_PREFIX; + keyMaterialFile = parquetFilePath.parentDirectory().appendPath( + fullPrefix + parquetFilePath.fileName() + KEY_MATERIAL_FILE_SUFFFIX); + } + + public String getKeyMaterial(String keyIDInFile) + throws ParquetCryptoRuntimeException + { + if (null == keyMaterialMap) { + loadKeyMaterialMap(); + } + return keyMaterialMap.get(keyIDInFile); + } + + private void loadKeyMaterialMap() + { + TrinoInputFile inputfile = trinoFileSystem.newInputFile(keyMaterialFile); + try (TrinoInputStream keyMaterialStream = inputfile.newStream()) { + JsonNode keyMaterialJson = objectMapper.readTree(keyMaterialStream); + keyMaterialMap = objectMapper.readValue(keyMaterialJson.traverse(), + new TypeReference>() {}); + } + catch (FileNotFoundException e) { + throw new ParquetCryptoRuntimeException("External key material not found at " + keyMaterialFile, e); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to get key material from " + keyMaterialFile, e); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java new file mode 100644 index 000000000000..eb05702732ba --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.AesGcmDecryptor; +import io.trino.parquet.crypto.AesMode; +import io.trino.parquet.crypto.ModuleCipherFactory; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TrinoCryptoConfigurationUtil; + +import java.lang.reflect.InvocationTargetException; +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +public class TrinoKeyToolkit +{ + public static final long CACHE_LIFETIME_DEFAULT_SECONDS = 10 * 60; // 10 minutes + + // KMS client two level cache: token -> KMSInstanceId -> KmsClient + static final TwoLevelCacheWithExpiration KMS_CLIENT_CACHE_PER_TOKEN = + KmsClientCache.INSTANCE.getCache(); + + // KEK two level cache for unwrapping: token -> KEK_ID -> KEK bytes + static final TwoLevelCacheWithExpiration KEK_READ_CACHE_PER_TOKEN = + KEKReadCache.INSTANCE.getCache(); + + private TrinoKeyToolkit() + { + } + + private enum KmsClientCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + private enum KEKReadCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + static String formatTokenForLog(String accessToken) + { + int maxTokenDisplayLength = 5; + if (accessToken.length() <= 
maxTokenDisplayLength) { + return accessToken; + } + return accessToken.substring(accessToken.length() - maxTokenDisplayLength); + } + + static class KeyWithMasterID + { + private final byte[] keyBytes; + private final String masterID; + + KeyWithMasterID(byte[] keyBytes, String masterID) + { + this.keyBytes = keyBytes; + this.masterID = masterID; + } + + byte[] getDataKey() + { + return keyBytes; + } + + String getMasterID() + { + return masterID; + } + } + + static class KeyEncryptionKey + { + private final byte[] kekBytes; + private final byte[] kekID; + private String encodedKekID; + private final String encodedWrappedKEK; + + KeyEncryptionKey(byte[] kekBytes, byte[] kekID, String encodedWrappedKEK) + { + this.kekBytes = kekBytes; + this.kekID = kekID; + this.encodedWrappedKEK = encodedWrappedKEK; + } + + byte[] getBytes() + { + return kekBytes; + } + + byte[] getID() + { + return kekID; + } + + String getEncodedID() + { + if (null == encodedKekID) { + encodedKekID = Base64.getEncoder().encodeToString(kekID); + } + return encodedKekID; + } + + String getEncodedWrappedKEK() + { + return encodedWrappedKEK; + } + } + + /** + * Decrypts encrypted key with "masterKey", using AES-GCM and the "aad" + * + * @param encodedEncryptedKey base64 encoded encrypted key + * @param masterKeyBytes encryption key + * @param aad additional authenticated data + * @return decrypted key + */ + public static byte[] decryptKeyLocally(String encodedEncryptedKey, byte[] masterKeyBytes, byte[] aad) + { + byte[] encryptedKey = Base64.getDecoder().decode(encodedEncryptedKey); + + AesGcmDecryptor keyDecryptor; + + keyDecryptor = (AesGcmDecryptor) ModuleCipherFactory.getDecryptor(AesMode.GCM, masterKeyBytes); + + return keyDecryptor.decrypt(encryptedKey, 0, encryptedKey.length, aad); + } + + static TrinoKmsClient getKmsClient(String kmsInstanceID, String kmsInstanceURL, ParquetReaderOptions trinoParquetCryptoConfig, + String accessToken, long cacheEntryLifetime) + { + ConcurrentMap 
kmsClientPerKmsInstanceCache = + KMS_CLIENT_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + TrinoKmsClient kmsClient = + kmsClientPerKmsInstanceCache.computeIfAbsent(kmsInstanceID, + (k) -> createAndInitKmsClient(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken)); + + return kmsClient; + } + + private static TrinoKmsClient createAndInitKmsClient(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, + String kmsInstanceURL, String accessToken) + { + Class> kmsClientClass = null; + TrinoKmsClient kmsClient; + + try { + kmsClientClass = TrinoCryptoConfigurationUtil.getClassFromConfig(trinoParquetCryptoConfig.getEncryptionKmsClientClass(), + TrinoKmsClient.class); + + if (null == kmsClientClass) { + throw new ParquetCryptoRuntimeException("Could not find class " + trinoParquetCryptoConfig.getEncryptionKmsClientClass()); + } + kmsClient = (TrinoKmsClient) kmsClientClass.getConstructor().newInstance(); + } + catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + throw new ParquetCryptoRuntimeException("Could not instantiate KmsClient class: " + + kmsClientClass, e); + } + + kmsClient.initialize(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken); + + return kmsClient; + } + + static class TrinoKmsClientAndDetails + { + public TrinoKmsClient getKmsClient() + { + return kmsClient; + } + + private TrinoKmsClient kmsClient; + private String kmsInstanceID; + private String kmsInstanceURL; + + public TrinoKmsClientAndDetails(TrinoKmsClient kmsClient, String kmsInstanceID, String kmsInstanceURL) + { + this.kmsClient = kmsClient; + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + } + + public String getKmsInstanceID() + { + return kmsInstanceID; + } + + public String getKmsInstanceURL() + { + return kmsInstanceURL; + } + } +} diff --git 
a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java
new file mode 100644
index 000000000000..6ca6cb0cb53e
--- /dev/null
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.crypto.keytools;
+
+import io.trino.parquet.ParquetReaderOptions;
+import io.trino.parquet.crypto.KeyAccessDeniedException;
+
+/**
+ * Client of a key management service (KMS) used to wrap and unwrap keys with master keys.
+ *
+ * Implementations are created by TrinoKeyToolkit through a no-argument constructor and then
+ * initialized with {@link #initialize}; instances are cached per access token and KMS instance ID.
+ */
+public interface TrinoKmsClient
+{
+    // NOTE(review): not referenced in the visible code; presumably the token used when none is configured - confirm
+    String KEY_ACCESS_TOKEN_DEFAULT = "DEFAULT";
+
+    /**
+     * Initializes the client. Called once, right after construction.
+     *
+     * @param trinoParquetCryptoConfig reader options carrying the encryption settings
+     * @param kmsInstanceID ID of the KMS instance to use
+     * @param kmsInstanceURL URL of the KMS instance to use
+     * @param accessToken access token for the KMS
+     * @throws KeyAccessDeniedException declared by the interface; the triggering conditions are implementation specific
+     */
+    void initialize(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, String kmsInstanceURL, String accessToken)
+            throws KeyAccessDeniedException;
+
+    /**
+     * Wraps a key with the given master key.
+     *
+     * @param keyBytes the key to wrap
+     * @param masterKeyIdentifier ID of the master key
+     * @return the wrapped key as a String (presumably the same encoding that {@link #unwrapKey} accepts - confirm)
+     * @throws KeyAccessDeniedException declared by the interface; the triggering conditions are implementation specific
+     */
+    String wrapKey(byte[] keyBytes, String masterKeyIdentifier)
+            throws KeyAccessDeniedException;
+
+    /**
+     * Unwraps a key with the given master key. TrinoFileKeyUnwrapper passes the "wrappedDEK" or
+     * "wrappedKEK" string of the key material and uses the result as raw key bytes.
+     *
+     * @param wrappedKey the wrapped key
+     * @param masterKeyIdentifier ID of the master key
+     * @return the unwrapped key bytes
+     * @throws KeyAccessDeniedException declared by the interface; the triggering conditions are implementation specific
+     */
+    byte[] unwrapKey(String wrappedKey, String masterKeyIdentifier)
+            throws KeyAccessDeniedException;
+}
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java
new file mode 100644
index 000000000000..8eb61c18c0e8
--- /dev/null
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed under the Apache License, Version
2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.crypto.keytools;
+
+import io.airlift.log.Logger;
+import io.trino.filesystem.Location;
+import io.trino.filesystem.TrinoFileSystem;
+import io.trino.parquet.ParquetReaderOptions;
+import io.trino.parquet.crypto.DecryptionKeyRetriever;
+import io.trino.parquet.crypto.FileDecryptionProperties;
+import io.trino.parquet.crypto.ParquetCryptoRuntimeException;
+import io.trino.parquet.crypto.TrinoDecryptionPropertiesFactory;
+
+/** Builds {@link FileDecryptionProperties} whose key retriever is a {@link TrinoFileKeyUnwrapper}; plaintext files are allowed. */
+public class TrinoPropertiesDrivenCryptoFactory
+        implements TrinoDecryptionPropertiesFactory
+{
+    private static final Logger LOG = Logger.get(TrinoPropertiesDrivenCryptoFactory.class);
+
+    @Override
+    public FileDecryptionProperties getFileDecryptionProperties(ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem)
+            throws ParquetCryptoRuntimeException
+    {
+        DecryptionKeyRetriever keyRetriever = new TrinoFileKeyUnwrapper(parquetReaderOptions, filePath, trinoFileSystem);
+
+        // airlift's Logger formats with String.format, so the placeholder is %s; "{}" would be printed literally
+        LOG.debug("File decryption properties for %s", filePath);
+
+        return FileDecryptionProperties.builder()
+                .withKeyRetriever(keyRetriever)
+                .withPlaintextFilesAllowed()
+                .build();
+    }
+}
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java new file mode 100644 index 000000000000..ca2e7d2d356d --- /dev/null +++ 
b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.crypto.keytools;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+/**
+ * Concurrent two-level cache with expiration of internal caches according to token lifetime.
+ * External cache is per token, internal is per String key.
+ *
+ * @param <V> Value
+ */
+class TwoLevelCacheWithExpiration<V>
+{
+    private final ConcurrentMap<String, ExpiringCacheEntry<ConcurrentMap<String, V>>> cache;
+    private volatile long lastCacheCleanupTimestamp;
+
+    TwoLevelCacheWithExpiration()
+    {
+        this.cache = new ConcurrentHashMap<>();
+        this.lastCacheCleanupTimestamp = System.currentTimeMillis();
+    }
+
+    /**
+     * Returns the internal cache for the token, replacing it with a fresh empty one if it is absent or expired.
+     *
+     * @param accessToken key of the external cache
+     * @param cacheEntryLifetime lifetime in milliseconds given to a newly created internal cache
+     */
+    ConcurrentMap<String, V> getOrCreateInternalCache(String accessToken, long cacheEntryLifetime)
+    {
+        ExpiringCacheEntry<ConcurrentMap<String, V>> externalCacheEntry =
+                cache.compute(accessToken, (token, cacheEntry) -> {
+                    if ((null == cacheEntry) || cacheEntry.isExpired()) {
+                        return new ExpiringCacheEntry<>(new ConcurrentHashMap<String, V>(), cacheEntryLifetime);
+                    }
+                    else {
+                        return cacheEntry;
+                    }
+                });
+        return externalCacheEntry.getCachedItem();
+    }
+
+    void removeCacheEntriesForToken(String accessToken)
+    {
+        cache.remove(accessToken);
+    }
+
+    void removeCacheEntriesForAllTokens()
+    {
+        cache.clear();
+    }
+
+    /**
+     * Removes expired internal caches if more than {@code cacheCleanupPeriod} milliseconds passed since the last cleanup.
+     */
+    public void checkCacheForExpiredTokens(long cacheCleanupPeriod)
+    {
+        long now = System.currentTimeMillis();
+
+        if (now > (lastCacheCleanupTimestamp + cacheCleanupPeriod)) {
+            synchronized (cache) {
+                if (now > (lastCacheCleanupTimestamp + cacheCleanupPeriod)) {
+                    removeExpiredEntriesFromCache();
+                    // Record when the cleanup ran; the period is added in the checks above, adding it here too would double the interval
+                    lastCacheCleanupTimestamp = now;
+                }
+            }
+        }
+    }
+
+    public void removeExpiredEntriesFromCache()
+    {
+        cache.values().removeIf(cacheEntry -> cacheEntry.isExpired());
+    }
+
+    public void remove(String accessToken)
+    {
+        cache.remove(accessToken);
+    }
+
+    public void clear()
+    {
+        cache.clear();
+    }
+
+    static class ExpiringCacheEntry<E>
+    {
+        private final long expirationTimestamp;
+        private final E cachedItem;
+
+        private ExpiringCacheEntry(E cachedItem, long expirationIntervalMillis)
+        {
+            this.expirationTimestamp = System.currentTimeMillis() + expirationIntervalMillis;
+            this.cachedItem = cachedItem;
+        }
+
+        private boolean isExpired()
+        {
+            final long now = System.currentTimeMillis();
+            return (now > expirationTimestamp);
+        }
+
+        private E getCachedItem()
+        {
+            return cachedItem;
+        }
+    }
+}
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java index 43defc21b834..1a955515fe50 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java @@ -15,7 +15,7 @@ import java.util.List; -public record BlockMetadata(long rowCount, List columns) +public record BlockMetadata(long rowCount, long totalByteSize, short ordinal, List columns) { public long getStartingPos() { diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java index 381260829869..0c9c85c95aee 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java 
@@ -23,6 +23,9 @@ import java.util.Set; +import static io.trino.parquet.ParquetEncoding.PLAIN_DICTIONARY; +import static io.trino.parquet.ParquetEncoding.RLE_DICTIONARY; + public abstract class ColumnChunkMetadata { protected int rowGroupOrdinal = -1; @@ -200,4 +203,16 @@ public String toString() decryptIfNeeded(); return "ColumnMetaData{" + properties.toString() + ", " + getFirstDataPageOffset() + "}"; } + + public boolean hasDictionaryPage() + { + EncodingStats stats = getEncodingStats(); + if (stats != null) { + // ensure there is a dictionary page and that it is used to encode data pages + return stats.hasDictionaryPages() && stats.hasDictionaryEncodedPages(); + } + + Set encodings = getEncodings(); + return (encodings.contains(PLAIN_DICTIONARY) || encodings.contains(RLE_DICTIONARY)); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java index 6901bb23a4e6..3293e980e719 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java @@ -25,6 +25,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; @@ -230,9 +231,11 @@ private static Map> getStatistics(PrunedBlockMet ImmutableMap.Builder> statistics = ImmutableMap.builderWithExpectedSize(descriptorsByPath.size()); for (ColumnDescriptor descriptor : descriptorsByPath.values()) { ColumnChunkMetadata columnMetaData = columnsMetadata.getColumnChunkMetaData(descriptor); - Statistics> columnStatistics = columnMetaData.getStatistics(); - if (columnStatistics != null) { - 
statistics.put(descriptor, columnStatistics); + if (!HiddenColumnChunkMetaData.isHiddenColumn(columnMetaData)) { + Statistics> columnStatistics = columnMetaData.getStatistics(); + if (columnStatistics != null) { + statistics.put(descriptor, columnStatistics); + } } } return statistics.buildOrThrow(); @@ -260,18 +263,20 @@ private static boolean dictionaryPredicatesMatch( { for (ColumnDescriptor descriptor : descriptorsByPath.values()) { ColumnChunkMetadata columnMetaData = columnsMetadata.getColumnChunkMetaData(descriptor); - if (!candidateColumns.contains(descriptor)) { - continue; - } - if (isOnlyDictionaryEncodingPages(columnMetaData)) { - Statistics> columnStatistics = columnMetaData.getStatistics(); - boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0; - // Early abort, predicate already filters block so no more dictionaries need be read - if (!parquetPredicate.matches(new DictionaryDescriptor( - descriptor, - nullAllowed, - readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { - return false; + if (!HiddenColumnChunkMetaData.isHiddenColumn(columnMetaData)) { + if (!candidateColumns.contains(descriptor)) { + continue; + } + if (isOnlyDictionaryEncodingPages(columnMetaData)) { + Statistics> columnStatistics = columnMetaData.getStatistics(); + boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0; + // Early abort, predicate already filters block so no more dictionaries need be read + if (!parquetPredicate.matches(new DictionaryDescriptor( + descriptor, + nullAllowed, + readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { + return false; + } } } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index fe0635646f98..294cfe0604b2 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ 
b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -15,26 +15,41 @@ import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; +import io.airlift.slice.BasicSliceInput; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.AesGcmEncryptor; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.KeyAccessDeniedException; +import io.trino.parquet.crypto.ModuleCipherFactory.ModuleType; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TagVerificationException; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.FileMetadata; import io.trino.parquet.metadata.ParquetMetadata; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; +import org.apache.parquet.format.BlockCipher.Decryptor; import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnCryptoMetaData; import org.apache.parquet.format.ColumnMetaData; import org.apache.parquet.format.Encoding; +import org.apache.parquet.format.EncryptionWithColumnKey; +import org.apache.parquet.format.FileCryptoMetaData; import org.apache.parquet.format.FileMetaData; import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.SchemaElement; import org.apache.parquet.format.Statistics; +import org.apache.parquet.format.Util; import org.apache.parquet.hadoop.metadata.ColumnPath; import 
org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -43,6 +58,7 @@ import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; @@ -56,7 +72,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.parquet.ParquetMetadataConverter.convertEncodingStats; import static io.trino.parquet.ParquetMetadataConverter.fromParquetStatistics; import static io.trino.parquet.ParquetMetadataConverter.getEncoding; @@ -69,6 +87,7 @@ import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static org.apache.parquet.format.Util.readFileCryptoMetaData; import static org.apache.parquet.format.Util.readFileMetaData; public final class MetadataReader @@ -76,13 +95,14 @@ public final class MetadataReader private static final Logger log = Logger.get(MetadataReader.class); private static final Slice MAGIC = Slices.utf8Slice("PAR1"); + private static final Slice EMAGIC = Slices.utf8Slice("PARE"); private static final int POST_SCRIPT_SIZE = Integer.BYTES + MAGIC.length(); // Typical 1GB files produced by Trino were found to have footer size between 30-40KB private static final int EXPECTED_FOOTER_SIZE = 48 * 1024; private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation, Optional fileDecryptor) throws IOException { // Parquet File Layout: @@ -93,7 +113,9 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< // 4 bytes: MetadataLength // MAGIC - 
validateParquet(dataSource.getEstimatedSize() >= MAGIC.length() + POST_SCRIPT_SIZE, dataSource.getId(), "%s is not a valid Parquet File", dataSource.getId()); + validateParquet((dataSource.getEstimatedSize() >= MAGIC.length() + POST_SCRIPT_SIZE) || + (dataSource.getEstimatedSize() >= EMAGIC.length() + POST_SCRIPT_SIZE), dataSource.getId(), + "%s is not a valid Parquet File", dataSource.getId()); // Read the tail of the file long estimatedFileSize = dataSource.getEstimatedSize(); @@ -101,8 +123,10 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< Slice buffer = dataSource.readTail(toIntExact(expectedReadSize)); Slice magic = buffer.slice(buffer.length() - MAGIC.length(), MAGIC.length()); - validateParquet(MAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s got: %s", MAGIC.toStringUtf8(), magic.toStringUtf8()); + validateParquet(MAGIC.equals(magic) || EMAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s or %s got: %s", MAGIC.toStringUtf8(), EMAGIC.toStringUtf8(), magic.toStringUtf8()); + boolean encryptedFooterMode = EMAGIC.equals(magic); + checkArgument(!encryptedFooterMode || !(fileDecryptor.isEmpty() || fileDecryptor.get().getDecryptionProperties() == null), "fileDecryptionProperties cannot be null when encryptedFooterMode is true"); int metadataLength = buffer.getInt(buffer.length() - POST_SCRIPT_SIZE); long metadataIndex = estimatedFileSize - POST_SCRIPT_SIZE - metadataLength; validateParquet( @@ -118,13 +142,44 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< } InputStream metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); - FileMetaData fileMetaData = readFileMetaData(metadataStream); - ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, dataSource.getId()); + Decryptor footerDecryptor = null; + byte[] aad = null; + + if (encryptedFooterMode) { + FileCryptoMetaData fileCryptoMetaData = 
readFileCryptoMetaData(metadataStream); + fileDecryptor.get().setFileCryptoMetaData(fileCryptoMetaData.getEncryption_algorithm(), true, fileCryptoMetaData.getKey_metadata()); + footerDecryptor = fileDecryptor.get().fetchFooterDecryptor(); + aad = AesCipher.createFooterAAD(fileDecryptor.get().getFileAAD()); + } + FileMetaData fileMetaData = readFileMetaData(metadataStream, footerDecryptor, aad); + if (!encryptedFooterMode && fileDecryptor.isPresent()) { + if (!fileMetaData.isSetEncryption_algorithm()) { // Plaintext file + fileDecryptor.get().setPlaintextFile(); + // Done to detect files that were not encrypted by mistake + if (!fileDecryptor.get().plaintextFilesAllowed()) { + throw new ParquetCryptoRuntimeException("Applying decryptor on plaintext file"); + } + } + else { // Encrypted file with plaintext footer + // if no fileDecryptor, can still read plaintext columns + fileDecryptor.get().setFileCryptoMetaData(fileMetaData.getEncryption_algorithm(), false, + fileMetaData.getFooter_signing_key_metadata()); + if (fileDecryptor.get().checkFooterIntegrity()) { + verifyFooterIntegrity(metadataStream, fileDecryptor.get(), metadataLength); + } + } + } + ParquetDataSourceId id = dataSource.getId(); + ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, id, fileDecryptor, encryptedFooterMode); + validateFileMetadata(id, parquetMetadata.getFileMetaData(), parquetWriteValidation); validateFileMetadata(dataSource.getId(), parquetMetadata.getFileMetaData(), parquetWriteValidation); return parquetMetadata; } - public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, ParquetDataSourceId dataSourceId) + public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, + ParquetDataSourceId dataSourceId, + Optional fileDecryptor, + boolean encryptedFooterMode) throws ParquetCorruptionException { List schema = fileMetaData.getSchema(); @@ -138,37 +193,79 @@ public static ParquetMetadata createParquetMetadata(FileMetaData 
fileMetaData, P List columns = rowGroup.getColumns(); validateParquet(!columns.isEmpty(), dataSourceId, "No columns in row group: %s", rowGroup); String filePath = columns.get(0).getFile_path(); + int columnOrdinal = -1; ImmutableList.Builder columnMetadataBuilder = ImmutableList.builderWithExpectedSize(columns.size()); for (ColumnChunk columnChunk : columns) { + columnOrdinal++; validateParquet( (filePath == null && columnChunk.getFile_path() == null) || (filePath != null && filePath.equals(columnChunk.getFile_path())), dataSourceId, "all column chunks of the same row group must be in the same file"); + ColumnCryptoMetaData cryptoMetaData = columnChunk.getCrypto_metadata(); ColumnMetaData metaData = columnChunk.meta_data; - String[] path = metaData.path_in_schema.stream() - .map(value -> value.toLowerCase(Locale.ENGLISH)) - .toArray(String[]::new); - ColumnPath columnPath = ColumnPath.get(path); - PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); - ColumnChunkMetadata column = ColumnChunkMetadata.get( - columnPath, - primitiveType, - CompressionCodecName.fromParquet(metaData.codec), - convertEncodingStats(metaData.encoding_stats), - readEncodings(metaData.encodings), - readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), - metaData.data_page_offset, - metaData.dictionary_page_offset, - metaData.num_values, - metaData.total_compressed_size, - metaData.total_uncompressed_size); - column.setColumnIndexReference(toColumnIndexReference(columnChunk)); - column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); - column.setBloomFilterOffset(metaData.bloom_filter_offset); - columnMetadataBuilder.add(column); + ColumnPath columnPath = null; + boolean encryptedMetadata = false; + if (cryptoMetaData == null) { + columnPath = getPath(metaData); + if (fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile()) { + // mark this column as plaintext in 
encrypted file decryptor + fileDecryptor.get().setColumnCryptoMetadata(columnPath, false, false, (byte[]) null, columnOrdinal); + } + } + else { // Encrypted column + if (cryptoMetaData.isSetENCRYPTION_WITH_FOOTER_KEY()) { // Column encrypted with footer key + if (!encryptedFooterMode) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key in file with plaintext footer"); + } + if (null == metaData) { + throw new ParquetCryptoRuntimeException("ColumnMetaData not set in Encryption with Footer key"); + } + if (fileDecryptor.isEmpty()) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key: No keys available"); + } + columnPath = getPath(metaData); + fileDecryptor.get().setColumnCryptoMetadata(columnPath, true, true, (byte[]) null, columnOrdinal); + } + else { // Column encrypted with column key + encryptedMetadata = true; + } + } + try { + if (encryptedMetadata) { + // TODO: We decrypt the column metadata before filter projection. This could send unnecessary traffic to KMS. + // parquet-mr uses lazy decryption, but that requires changing ColumnChunkMetadata. We will improve it later. 
+ metaData = decryptMetadata(rowGroup, cryptoMetaData, columnChunk, fileDecryptor.get(), columnOrdinal); + columnPath = getPath(metaData); + } + PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); + ColumnChunkMetadata column = ColumnChunkMetadata.get( + columnPath, + primitiveType, + CompressionCodecName.fromParquet(metaData.codec), + convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + column.setColumnIndexReference(toColumnIndexReference(columnChunk)); + column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); + column.setBloomFilterOffset(metaData.bloom_filter_offset); + + if (rowGroup.isSetOrdinal()) { + column.setRowGroupOrdinal(rowGroup.getOrdinal()); + } + columnMetadataBuilder.add(column); + } + catch (KeyAccessDeniedException e) { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(columnPath, filePath); + columnMetadataBuilder.add(column); + } } - blocks.add(new BlockMetadata(rowGroup.getNum_rows(), columnMetadataBuilder.build())); + blocks.add(new BlockMetadata(rowGroup.getNum_rows(), rowGroup.getTotal_byte_size(), rowGroup.getOrdinal(), columnMetadataBuilder.build())); } } @@ -274,6 +371,25 @@ public static org.apache.parquet.column.statistics.Statistics> readStats(Optio return columnStatistics; } + /** + * If a column is encrypted and the user does not provide the correct key to decrypt it, that column is hidden from the current request. + * This method finds the first non-hidden column. + * + * @param block BlockMetaData + * @return index of the first non-hidden column, or null if all columns are hidden. 
+ */ + public static Integer findFirstNonHiddenColumnId(BlockMetadata block) + { + List columns = block.columns(); + for (int i = 0; i < columns.size(); i++) { + if (!HiddenColumnChunkMetaData.isHiddenColumn(columns.get(i))) { + return i; + } + } + // all columns are hidden (encrypted but not accessible to current user) + return null; + } + private static boolean isStringType(PrimitiveType type) { if (type.getLogicalTypeAnnotation() == null) { @@ -373,4 +489,75 @@ private static void validateFileMetadata(ParquetDataSourceId dataSourceId, FileM Optional.ofNullable(fileMetaData.getKeyValueMetaData().get("writer.time.zone"))); writeValidation.validateColumns(dataSourceId, fileMetaData.getSchema()); } + + private static ColumnMetaData decryptMetadata(RowGroup rowGroup, ColumnCryptoMetaData cryptoMetaData, ColumnChunk columnChunk, InternalFileDecryptor fileDecryptor, int columnOrdinal) + { + EncryptionWithColumnKey columnKeyStruct = cryptoMetaData.getENCRYPTION_WITH_COLUMN_KEY(); + List pathList = columnKeyStruct.getPath_in_schema().stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()); + + byte[] columnKeyMetadata = columnKeyStruct.getKey_metadata(); + ColumnPath columnPath = ColumnPath.get(pathList.toArray(new String[pathList.size()])); + byte[] encryptedMetadataBuffer = columnChunk.getEncrypted_column_metadata(); + + // Decrypt the ColumnMetaData + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.setColumnCryptoMetadata(columnPath, true, false, columnKeyMetadata, columnOrdinal); + ByteArrayInputStream tempInputStream = new ByteArrayInputStream(encryptedMetadataBuffer); + byte[] columnMetaDataAAD = AesCipher.createModuleAAD(fileDecryptor.getFileAAD(), ModuleType.ColumnMetaData, rowGroup.ordinal, columnOrdinal, -1); + try { + return Util.readColumnMetaData(tempInputStream, columnDecryptionSetup.getMetaDataDecryptor(), columnMetaDataAAD); + } + catch (IOException e) { + throw new 
ParquetCryptoRuntimeException(columnPath + ". Failed to decrypt column metadata", e); + } + } + + /*public static ColumnChunkMetadata buildColumnChunkMetaData(Optional fileCreatedBy, ColumnMetaData metaData, ColumnPath columnPath, PrimitiveType type) + { + return ColumnChunkMetadata.get( + columnPath, + type, + CompressionCodecName.fromParquet(metaData.codec), + PARQUET_METADATA_CONVERTER.convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(fileCreatedBy, Optional.ofNullable(metaData.statistics), type), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + }*/ + + private static ColumnPath getPath(ColumnMetaData metaData) + { + String[] path = metaData.path_in_schema.stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)) + .toArray(String[]::new); + return ColumnPath.get(path); + } + + private static void verifyFooterIntegrity(InputStream metadataStream, InternalFileDecryptor fileDecryptor, int combinedFooterLength) + throws IOException + { + byte[] nonce = new byte[AesCipher.NONCE_LENGTH]; + metadataStream.read(nonce); + byte[] gcmTag = new byte[AesCipher.GCM_TAG_LENGTH]; + metadataStream.read(gcmTag); + + AesGcmEncryptor footerSigner = fileDecryptor.createSignedFooterEncryptor(); + int footerSignatureLength = AesCipher.NONCE_LENGTH + AesCipher.GCM_TAG_LENGTH; + byte[] serializedFooter = new byte[combinedFooterLength - footerSignatureLength]; + + //InputStream doesn't implement reset(). 
Here is to workaround + ((BasicSliceInput) metadataStream).setPosition(0); + metadataStream.read(serializedFooter, 0, serializedFooter.length); + + byte[] signedFooterAAD = AesCipher.createFooterAAD(fileDecryptor.getFileAAD()); + byte[] encryptedFooterBytes = footerSigner.encrypt(false, serializedFooter, nonce, signedFooterAAD); + byte[] calculatedTag = new byte[AesCipher.GCM_TAG_LENGTH]; + System.arraycopy(encryptedFooterBytes, encryptedFooterBytes.length - AesCipher.GCM_TAG_LENGTH, calculatedTag, 0, AesCipher.GCM_TAG_LENGTH); + if (!Arrays.equals(gcmTag, calculatedTag)) { + throw new TagVerificationException("Signature mismatch in plaintext footer"); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java index d8ec35c52fbe..799d4b111654 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java @@ -16,17 +16,24 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; +import io.airlift.slice.Slice; import io.trino.parquet.DataPage; import io.trino.parquet.DataPageV1; import io.trino.parquet.DataPageV2; import io.trino.parquet.DictionaryPage; import io.trino.parquet.Page; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.ModuleCipherFactory; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.CompressionCodec; +import 
org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -35,6 +42,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.parquet.ParquetCompressionUtils.decompress; import static io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages; import static java.util.Objects.requireNonNull; @@ -49,6 +57,10 @@ public final class PageReader private boolean dictionaryAlreadyRead; private int dataPageReadCount; + private int pageIndex; + private final BlockCipher.Decryptor blockDecryptor; + private byte[] dataPageAAD; + private byte[] dictionaryPageAAD; public static PageReader createPageReader( ParquetDataSourceId dataSourceId, @@ -56,7 +68,8 @@ public static PageReader createPageReader( ColumnChunkMetadata metadata, ColumnDescriptor columnDescriptor, @Nullable OffsetIndex offsetIndex, - Optional fileCreatedBy) + Optional fileCreatedBy, + Optional fileDecryptor) { // Parquet schema may specify a column definition as OPTIONAL even though there are no nulls in the actual data. 
// Row-group column statistics can be used to identify such cases and switch to faster non-nullable read @@ -64,20 +77,36 @@ public static PageReader createPageReader( Statistics> columnStatistics = metadata.getStatistics(); boolean hasNoNulls = columnStatistics != null && columnStatistics.getNumNulls() == 0; boolean hasOnlyDictionaryEncodedPages = isOnlyDictionaryEncodingPages(metadata); + byte[] fileAad = null; + BlockCipher.Decryptor dataDecryptor = null; + int columnOrdinal = -1; + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + fileAad = fileDecryptor.get().getFileAAD(); + dataDecryptor = columnDecryptionSetup.getDataDecryptor(); + columnOrdinal = columnDecryptionSetup.getOrdinal(); + } ParquetColumnChunkIterator compressedPages = new ParquetColumnChunkIterator( dataSourceId, fileCreatedBy, columnDescriptor, metadata, columnChunk, - offsetIndex); + offsetIndex, + fileDecryptor, + columnOrdinal); return new PageReader( dataSourceId, metadata.getCodec().getParquetCompressionCodec(), compressedPages, hasOnlyDictionaryEncodedPages, - hasNoNulls); + hasNoNulls, + dataDecryptor, + fileAad, + metadata.getRowGroupOrdinal(), + columnOrdinal); } @VisibleForTesting @@ -86,13 +115,22 @@ public PageReader( CompressionCodec codec, Iterator extends Page> compressedPages, boolean hasOnlyDictionaryEncodedPages, - boolean hasNoNulls) + boolean hasNoNulls, + BlockCipher.Decryptor blockDecryptor, + byte[] fileAAD, + int rowGroupOrdinal, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.codec = codec; this.compressedPages = Iterators.peekingIterator(compressedPages); this.hasOnlyDictionaryEncodedPages = hasOnlyDictionaryEncodedPages; this.hasNoNulls = hasNoNulls; + this.blockDecryptor = blockDecryptor; + if (null != blockDecryptor) { + dataPageAAD = 
AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DataPage, rowGroupOrdinal, columnOrdinal, 0); + dictionaryPageAAD = AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DictionaryPage, rowGroupOrdinal, columnOrdinal, -1); + } } public boolean hasNoNulls() @@ -114,18 +152,23 @@ public DataPage readPage() checkState(compressedPage instanceof DataPage, "Found page %s instead of a DataPage", compressedPage); dataPageReadCount++; try { + if (null != blockDecryptor) { + AesCipher.quickUpdatePageAAD(dataPageAAD, ((DataPage) compressedPage).getPageIndex()); + } + Slice slice = decryptSliceIfNeeded(compressedPage.getSlice(), dataPageAAD); if (compressedPage instanceof DataPageV1 dataPageV1) { if (!arePagesCompressed()) { return dataPageV1; } return new DataPageV1( - decompress(dataSourceId, codec, dataPageV1.getSlice(), dataPageV1.getUncompressedSize()), + decompress(dataSourceId, codec, slice, dataPageV1.getUncompressedSize()), dataPageV1.getValueCount(), dataPageV1.getUncompressedSize(), dataPageV1.getFirstRowIndex(), dataPageV1.getRepetitionLevelEncoding(), dataPageV1.getDefinitionLevelEncoding(), - dataPageV1.getValueEncoding()); + dataPageV1.getValueEncoding(), + dataPageV1.getPageIndex()); } DataPageV2 dataPageV2 = (DataPageV2) compressedPage; if (!dataPageV2.isCompressed()) { @@ -141,11 +184,12 @@ public DataPage readPage() dataPageV2.getRepetitionLevels(), dataPageV2.getDefinitionLevels(), dataPageV2.getDataEncoding(), - decompress(dataSourceId, codec, dataPageV2.getSlice(), uncompressedSize), + decompress(dataSourceId, codec, slice, uncompressedSize), dataPageV2.getUncompressedSize(), dataPageV2.getFirstRowIndex(), dataPageV2.getStatistics(), - false); + false, + dataPageV2.getPageIndex()); } catch (IOException e) { throw new RuntimeException("Could not decompress page", e); @@ -162,8 +206,9 @@ public DictionaryPage readDictionaryPage() } try { DictionaryPage compressedDictionaryPage = (DictionaryPage) compressedPages.next(); + Slice 
slice = decryptSliceIfNeeded(compressedDictionaryPage.getSlice(), dictionaryPageAAD); return new DictionaryPage( - decompress(dataSourceId, codec, compressedDictionaryPage.getSlice(), compressedDictionaryPage.getUncompressedSize()), + decompress(dataSourceId, codec, slice, compressedDictionaryPage.getUncompressedSize()), compressedDictionaryPage.getDictionarySize(), compressedDictionaryPage.getEncoding()); } @@ -199,4 +244,14 @@ private void verifyDictionaryPageRead() { checkArgument(dictionaryAlreadyRead, "Dictionary has to be read first"); } + + private Slice decryptSliceIfNeeded(Slice slice, byte[] aad) + throws IOException + { + if (blockDecryptor == null) { + return slice; + } + byte[] plainText = blockDecryptor.decrypt(slice.getBytes(), aad); + return wrappedBuffer(plainText); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java index 235c1b2d3d76..720d5f16151f 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java @@ -19,15 +19,21 @@ import io.trino.parquet.Page; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.ModuleCipherFactory; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.DataPageHeader; import org.apache.parquet.format.DataPageHeaderV2; import org.apache.parquet.format.DictionaryPageHeader; import 
org.apache.parquet.format.PageHeader; import org.apache.parquet.format.Util; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -51,6 +57,9 @@ public final class ParquetColumnChunkIterator private long valueCount; private int dataPageCount; + private Optional fileDecryptor; + private int columnOrdinal; + private boolean dictionaryWasRead; public ParquetColumnChunkIterator( ParquetDataSourceId dataSourceId, @@ -58,7 +67,9 @@ public ParquetColumnChunkIterator( ColumnDescriptor descriptor, ColumnChunkMetadata metadata, ChunkedInputStream input, - @Nullable OffsetIndex offsetIndex) + @Nullable OffsetIndex offsetIndex, + Optional fileDecryptor, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -66,6 +77,8 @@ public ParquetColumnChunkIterator( this.metadata = requireNonNull(metadata, "metadata is null"); this.input = requireNonNull(input, "input is null"); this.offsetIndex = offsetIndex; + this.fileDecryptor = fileDecryptor; + this.columnOrdinal = columnOrdinal; } @Override @@ -79,8 +92,32 @@ public Page next() { checkState(hasNext(), "No more data left to read in column (%s), metadata (%s), valueCount %s, dataPageCount %s", descriptor, metadata, valueCount, dataPageCount); + byte[] dataPageHeaderAAD = null; + BlockCipher.Decryptor headerBlockDecryptor = null; + InternalColumnDecryptionSetup columnDecryptionSetup = null; + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(descriptor.getPath()); + columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + headerBlockDecryptor = columnDecryptionSetup.getMetaDataDecryptor(); + if (null != headerBlockDecryptor) { + dataPageHeaderAAD = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), + ModuleCipherFactory.ModuleType.DataPageHeader, 
metadata.getRowGroupOrdinal(), columnOrdinal, dataPageCount); + } + } try { - PageHeader pageHeader = readPageHeader(); + byte[] pageHeaderAAD = dataPageHeaderAAD; + if (null != headerBlockDecryptor) { + // Important: this verifies file integrity (makes sure dictionary page had not been removed) + if (!(dictionaryWasRead || !metadata.hasDictionaryPage())) { + pageHeaderAAD = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), + ModuleCipherFactory.ModuleType.DictionaryPageHeader, metadata.getRowGroupOrdinal(), + columnOrdinal, -1); + } + else { + AesCipher.quickUpdatePageAAD(dataPageHeaderAAD, dataPageCount); + } + } + PageHeader pageHeader = readPageHeader(headerBlockDecryptor, pageHeaderAAD); int uncompressedPageSize = pageHeader.getUncompressed_page_size(); int compressedPageSize = pageHeader.getCompressed_page_size(); Page result = null; @@ -90,13 +127,14 @@ public Page next() throw new ParquetCorruptionException(dataSourceId, "Column (%s) has a dictionary page after the first position in column chunk", descriptor); } result = readDictionaryPage(pageHeader, pageHeader.getUncompressed_page_size(), pageHeader.getCompressed_page_size()); + dictionaryWasRead = true; break; case DATA_PAGE: - result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; case DATA_PAGE_V2: - result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; default: @@ -110,10 +148,10 @@ public Page next() } } - private PageHeader readPageHeader() + private PageHeader readPageHeader(BlockCipher.Decryptor headerBlockDecryptor, byte[] 
pageHeaderAAD) throws IOException { - return Util.readPageHeader(input); + return Util.readPageHeader(input, headerBlockDecryptor, pageHeaderAAD); } private boolean hasMorePages(long valuesCountReadSoFar, int dataPageCountReadSoFar) @@ -139,7 +177,8 @@ private DataPageV1 readDataPageV1( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeader dataHeaderV1 = pageHeader.getData_page_header(); @@ -151,14 +190,16 @@ private DataPageV1 readDataPageV1( firstRowIndex, getParquetEncoding(Encoding.valueOf(dataHeaderV1.getRepetition_level_encoding().name())), getParquetEncoding(Encoding.valueOf(dataHeaderV1.getDefinition_level_encoding().name())), - getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name()))); + getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name())), + pageIndex); } private DataPageV2 readDataPageV2( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeaderV2 dataHeaderV2 = pageHeader.getData_page_header_v2(); @@ -178,7 +219,8 @@ private DataPageV2 readDataPageV2( fileCreatedBy, Optional.ofNullable(dataHeaderV2.getStatistics()), descriptor.getPrimitiveType()), - dataHeaderV2.isIs_compressed()); + dataHeaderV2.isIs_compressed(), + pageIndex); } private static OptionalLong getFirstRowIndex(int pageIndex, OffsetIndex offsetIndex) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 0ad000ccd420..128375e5a32e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -30,6 +30,9 @@ import io.trino.parquet.ParquetReaderOptions; import 
io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.PrimitiveField; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; import io.trino.parquet.predicate.TupleDomainParquetPredicate; @@ -129,6 +132,7 @@ public class ParquetReader private final Map> codecMetrics; private long columnIndexRowsFiltered = -1; + private final Optional fileDecryptor; public ParquetReader( Optional fileCreatedBy, @@ -140,7 +144,8 @@ public ParquetReader( ParquetReaderOptions options, Function exceptionTransform, Optional parquetPredicate, - Optional writeValidation) + Optional writeValidation, + Optional fileDecryptor) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -156,6 +161,7 @@ public ParquetReader( this.maxBatchSize = options.getMaxReadBlockRowCount(); this.columnReaders = new HashMap<>(); this.maxBytesPerCell = new HashMap<>(); + this.fileDecryptor = fileDecryptor; this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); validateWrite( @@ -264,7 +270,7 @@ public long lastBatchStartRow() return firstRowIndexInGroup + nextRowInGroup - batchSize; } - private int nextBatch() + public int nextBatch() throws IOException { if (nextRowInGroup >= currentGroupRowCount && !advanceToNextRowGroup()) { @@ -457,9 +463,16 @@ private ColumnChunk readPrimitive(PrimitiveField field) offsetIndex = getFilteredOffsetIndex(rowRanges, currentRowGroup, currentBlockMetadata.getRowCount(), metadata.getPath()); } ChunkedInputStream columnChunkInputStream = chunkReaders.get(new ChunkKey(fieldId, currentRowGroup)); - columnReader.setPageReader( - createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy), - Optional.ofNullable(rowRanges)); + if 
(isEncryptedColumn(fileDecryptor, columnDescriptor)) { + columnReader.setPageReader( + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptor), + Optional.ofNullable(rowRanges)); + } + else { + columnReader.setPageReader( + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptor), + Optional.ofNullable(rowRanges)); + } } ColumnChunk columnChunk = columnReader.readPrimitive(); @@ -491,6 +504,19 @@ public Metrics getMetrics() return new Metrics(metrics.buildOrThrow()); } + private ColumnChunkMetadata getColumnChunkMetaData(BlockMetadata blockMetaData, ColumnDescriptor columnDescriptor) + throws IOException + { + for (ColumnChunkMetadata metadata : blockMetaData.columns()) { + if (!HiddenColumnChunkMetaData.isHiddenColumn(metadata)) { + if (metadata.getPath().equals(ColumnPath.get(columnDescriptor.getPath()))) { + return metadata; + } + } + } + throw new ParquetCorruptionException(dataSource.getId(), "Metadata is missing for column: %s", columnDescriptor); + } + private void initializeColumnReaders() { for (PrimitiveField field : primitiveFields) { @@ -612,4 +638,10 @@ private void validateWrite(java.util.function.Predicate throw new ParquetCorruptionException(dataSource.getId(), "Write validation failed: " + messageFormat, args); } } + + private boolean isEncryptedColumn(Optional fileDecryptor, ColumnDescriptor columnDescriptor) + { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + return fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile() && fileDecryptor.get().getColumnSetup(columnPath).isEncrypted(); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 651d86040ef5..9eb40a5665e4 100644 --- 
a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -237,7 +237,7 @@ public void validate(ParquetDataSource input) checkState(validationBuilder.isPresent(), "validation is not enabled"); ParquetWriteValidation writeValidation = validationBuilder.get().build(); try { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation)); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation), Optional.empty()); try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { // fully load the page @@ -293,7 +293,8 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada return new RuntimeException(exception); }, Optional.empty(), - Optional.of(writeValidation)); + Optional.of(writeValidation), + Optional.empty()); } private void recordValidation(Consumer task) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java index 9f7918115838..e6cdd9825e77 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java @@ -225,7 +225,7 @@ public void setup() testData.getColumnNames(), testData.getPages()), new ParquetReaderOptions()); - parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); columnNames = columns.stream() .map(TpchColumn::getColumnName) .collect(toImmutableList()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java 
b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index febdaccf617b..59280c6de102 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -164,6 +164,7 @@ public static ParquetReader createParquetReader( return new RuntimeException(exception); }, Optional.of(parquetPredicate), + Optional.empty(), Optional.empty()); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java index fc47c42d8d82..448ef7dc26a8 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java @@ -105,7 +105,7 @@ public int read() throws IOException { ColumnReader columnReader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false, null, null, -1, -1); columnReader.setPageReader(pageReader, Optional.empty()); int rowsRead = 0; while (rowsRead < dataPositions) { @@ -133,7 +133,8 @@ private DataPage createDataPage(ValuesWriter writer, int valuesCount) OptionalLong.empty(), RLE, RLE, - getParquetEncoding(writer.getEncoding())); + getParquetEncoding(writer.getEncoding()), + 0); } protected static void run(Class> clazz) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java index 6a3fccb1e281..37dde42f5e57 100644 --- 
a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java @@ -564,7 +564,11 @@ else if (dictionaryEncoding == DictionaryEncoding.MIXED) { UNCOMPRESSED, inputPages.iterator(), dictionaryEncoding == DictionaryEncoding.ALL || (dictionaryEncoding == DictionaryEncoding.MIXED && testingPages.size() == 1), - false); + false, + null, + null, + -1, + -1); } private static List createDataPages(List testingPages, ValuesWriter encoder, int maxDef, boolean required) @@ -599,7 +603,8 @@ private static DataPage createDataPage(TestingPage testingPage, ValuesWriter enc valueCount * 4, OptionalLong.of(testingPage.pageRowRange().start()), null, - false); + false, + 0); encoder.reset(); return dataPage; } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java index 445b61268c33..8b8fe067c88a 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java @@ -660,7 +660,8 @@ protected static DataPage createDataPage( OptionalLong.empty(), getParquetEncoding(repetitionWriter.getEncoding()), getParquetEncoding(definitionWriter.getEncoding()), - encoding); + encoding, + 0); } return new DataPageV2( valueCount, @@ -673,7 +674,8 @@ protected static DataPage createDataPage( definitionBytes.length + repetitionBytes.length + valueBytes.length, OptionalLong.empty(), null, - false); + false, + 0); } protected static PageReader getPageReaderMock(List dataPages, @Nullable DictionaryPage dictionaryPage) @@ -699,7 +701,7 @@ protected static PageReader getPageReaderMock(List dataPages, @Nullabl return ((DataPageV2) page).getDataEncoding(); }) .allMatch(encoding -> encoding == PLAIN_DICTIONARY || encoding == 
RLE_DICTIONARY), - hasNoNulls); + hasNoNulls, null, null, -1, -1); } private DataPage createDataPage(DataPageVersion version, ParquetEncoding encoding, ValuesWriter writer, int valueCount) @@ -713,7 +715,7 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin { Slice slice = Slices.wrappedBuffer(writer.getBytes().toByteArray()); if (version == V1) { - return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding); + return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding, 0); } return new DataPageV2( valueCount, @@ -726,7 +728,8 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin slice.length(), firstRowIndex, null, - false); + false, + 0); } private static ValuesWriter getLevelsWriter(int maxLevel, int valueCount) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java new file mode 100644 index 000000000000..ac6981666b57 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import io.trino.parquet.crypto.ColumnEncryptionProperties; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.FileDecryptionProperties; +import io.trino.parquet.crypto.FileEncryptionProperties; +import io.trino.parquet.crypto.ParquetCipher; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class EncDecPropertiesHelper +{ + private EncDecPropertiesHelper() + { + } + + private static class DecryptionKeyRetrieverMock + implements DecryptionKeyRetriever + { + private final Map keyMap = new HashMap<>(); + + public DecryptionKeyRetrieverMock putKey(String keyId, byte[] keyBytes) + { + keyMap.put(keyId, keyBytes); + return this; + } + + @Override + public byte[] getKey(byte[] keyMetaData) + { + String keyId = new String(keyMetaData, StandardCharsets.UTF_8); + return keyMap.get(keyId); + } + } + + private static final byte[] FOOTER_KEY = {0x01, 0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}; + private static final byte[] FOOTER_KEY_METADATA = "footkey".getBytes(StandardCharsets.UTF_8); + private static final byte[] COL_KEY = {0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11}; + private static final byte[] COL_KEY_METADATA = "col".getBytes(StandardCharsets.UTF_8); + + public static FileDecryptionProperties getFileDecryptionProperties() + throws IOException + { + DecryptionKeyRetrieverMock keyRetriever = new DecryptionKeyRetrieverMock(); + keyRetriever.putKey("footkey", FOOTER_KEY); + keyRetriever.putKey("col", COL_KEY); + return FileDecryptionProperties.builder().withPlaintextFilesAllowed().withKeyRetriever(keyRetriever).build(); + } + + public static FileEncryptionProperties getFileEncryptionProperties(List encryptColumns, ParquetCipher cipher, 
Boolean encryptFooter) + { + if (encryptColumns.size() == 0) { + return null; + } + + Map columnPropertyMap = new HashMap<>(); + for (String encryptColumn : encryptColumns) { + ColumnPath columnPath = ColumnPath.fromDotString(encryptColumn); + ColumnEncryptionProperties columnEncryptionProperties = ColumnEncryptionProperties.builder(columnPath) + .withKey(COL_KEY) + .withKeyMetaData(COL_KEY_METADATA) + .build(); + columnPropertyMap.put(columnPath, columnEncryptionProperties); + } + + FileEncryptionProperties.Builder encryptionPropertiesBuilder = + FileEncryptionProperties.builder(FOOTER_KEY) + .withFooterKeyMetadata(FOOTER_KEY_METADATA) + .withAlgorithm(cipher) + .withEncryptedColumns(columnPropertyMap); + + if (!encryptFooter) { + encryptionPropertiesBuilder.withPlaintextFooter(); + } + + return encryptionPropertiesBuilder.build(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java new file mode 100644 index 000000000000..d7677525ef13 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import org.apache.parquet.example.data.simple.SimpleGroup; + +public class EncryptionTestFile +{ + private final String fileName; + private final SimpleGroup[] fileContent; + + public EncryptionTestFile(String fileName, SimpleGroup[] fileContent) + { + this.fileName = fileName; + this.fileContent = fileContent; + } + + public String getFileName() + { + return this.fileName; + } + + public SimpleGroup[] getFileContent() + { + return this.fileContent; + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java new file mode 100644 index 000000000000..dd46ccb689b0 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.hadoop.fs.FSDataInputStream; + +import java.io.IOException; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class MockInputStreamTail +{ + public static final int MAX_SUPPORTED_PADDING_BYTES = 64; + private static final int MAXIMUM_READ_LENGTH = Integer.MAX_VALUE - (MAX_SUPPORTED_PADDING_BYTES + 1); + + private final Slice tailSlice; + private final long fileSize; + + private MockInputStreamTail(long fileSize, Slice tailSlice) + { + this.tailSlice = requireNonNull(tailSlice, "tailSlice is null"); + this.fileSize = fileSize; + checkArgument(fileSize >= 0, "fileSize is negative: %s", fileSize); + checkArgument(tailSlice.length() <= fileSize, "length (%s) is greater than fileSize (%s)", tailSlice.length(), fileSize); + } + + public static MockInputStreamTail readTail(String path, long paddedFileSize, FSDataInputStream inputStream, int length) + throws IOException + { + checkArgument(length >= 0, "length is negative: %s", length); + checkArgument(length <= MAXIMUM_READ_LENGTH, "length (%s) exceeds maximum (%s)", length, MAXIMUM_READ_LENGTH); + long readSize = min(paddedFileSize, (length + MAX_SUPPORTED_PADDING_BYTES)); + long position = paddedFileSize - readSize; + // Actual read will be 1 byte larger to ensure we encounter an EOF where expected + byte[] buffer = new byte[toIntExact(readSize + 1)]; + int bytesRead = 0; + long startPos = inputStream.getPos(); + try { + inputStream.seek(position); + while (bytesRead < buffer.length) { + int n = inputStream.read(buffer, bytesRead, buffer.length - bytesRead); + if (n < 0) { + break; + } + bytesRead += n; + } + } + finally { + inputStream.seek(startPos); + } + if 
(bytesRead > readSize) { + throw rejectInvalidFileSize(path, paddedFileSize); + } + return new MockInputStreamTail(position + bytesRead, Slices.wrappedBuffer(buffer, max(0, bytesRead - length), min(bytesRead, length))); + } + + public static long readTailForFileSize(String path, long paddedFileSize, FSDataInputStream inputStream) + throws IOException + { + long position = max(paddedFileSize - MAX_SUPPORTED_PADDING_BYTES, 0); + long maxEOFAt = paddedFileSize + 1; + long startPos = inputStream.getPos(); + try { + inputStream.seek(position); + int c; + while (position < maxEOFAt) { + c = inputStream.read(); + if (c < 0) { + return position; + } + position++; + } + throw rejectInvalidFileSize(path, paddedFileSize); + } + finally { + inputStream.seek(startPos); + } + } + + private static IOException rejectInvalidFileSize(String path, long reportedSize) + throws IOException + { + throw new IOException(format("Incorrect file size (%s) for file (end of stream not reached): %s", reportedSize, path)); + } + + public long getFileSize() + { + return fileSize; + } + + public Slice getTailSlice() + { + return tailSlice; + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java new file mode 100644 index 000000000000..2652e2da3301 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java @@ -0,0 +1,335 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.airlift.units.DataSize; +import io.trino.memory.context.AggregatedMemoryContext; +import io.trino.parquet.ChunkReader; +import io.trino.parquet.DiskRange; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; +import org.apache.hadoop.fs.FSDataInputStream; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.lang.Math.toIntExact; +import static java.util.Comparator.comparingLong; +import static java.util.Objects.requireNonNull; + +public class MockParquetDataSource + implements ParquetDataSource +{ + private final ParquetDataSourceId id; + private final long estimatedSize; + private final FSDataInputStream inputStream; + private long readTimeNanos; + private long readBytes; + private final ParquetReaderOptions options; + + public MockParquetDataSource( + ParquetDataSourceId id, + long estimatedSize, + FSDataInputStream inputStream, + ParquetReaderOptions options) + { + this.id = requireNonNull(id, "id is null"); + this.estimatedSize = estimatedSize; + this.inputStream = inputStream; + this.options = requireNonNull(options, "options is null"); + } + + @Override + public ParquetDataSourceId getId() + { + return id; + } + + @Override + public final long getReadBytes() + { + return 
readBytes; + } + + @Override + public long getReadTimeNanos() + { + return readTimeNanos; + } + + @Override + public final long getEstimatedSize() + { + return estimatedSize; + } + + @Override + public void close() + throws IOException + { + inputStream.close(); + } + + @Override + public Slice readTail(int length) + { + long start = System.nanoTime(); + Slice tailSlice; + try { + // Handle potentially imprecise file lengths by reading the footer + MockInputStreamTail fileTail = MockInputStreamTail.readTail(getId().toString(), getEstimatedSize(), inputStream, length); + tailSlice = fileTail.getTailSlice(); + } + catch (IOException e) { + throw new RuntimeException("Error reading tail from %s with length"); + } + long currentReadTimeNanos = System.nanoTime() - start; + + readTimeNanos += currentReadTimeNanos; + readBytes += tailSlice.length(); + return tailSlice; + } + + @Override + public final Slice readFully(long position, int length) + { + byte[] buffer = new byte[length]; + readFully(position, buffer, 0, length); + return Slices.wrappedBuffer(buffer); + } + + @Override + public final Map planRead(ListMultimap diskRanges, AggregatedMemoryContext memoryContext) + { + requireNonNull(diskRanges, "diskRanges is null"); + + if (diskRanges.isEmpty()) { + return ImmutableMap.of(); + } + + return planChunksRead(diskRanges, memoryContext).asMap() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> new ChunkedInputStream(entry.getValue()))); + } + + public ListMultimap planChunksRead(ListMultimap diskRanges, AggregatedMemoryContext memoryContext) + { + requireNonNull(diskRanges, "diskRanges is null"); + + if (diskRanges.isEmpty()) { + return ImmutableListMultimap.of(); + } + + // + // Note: this code does not use the stream APIs to avoid any extra object allocation + // + + // split disk ranges into "big" and "small" + ImmutableListMultimap.Builder smallRangesBuilder = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder 
largeRangesBuilder = ImmutableListMultimap.builder(); + for (Map.Entry entry : diskRanges.entries()) { + if (entry.getValue().getLength() <= options.getMaxBufferSize().toBytes()) { + smallRangesBuilder.put(entry); + } + else { + largeRangesBuilder.put(entry); + } + } + ListMultimap smallRanges = smallRangesBuilder.build(); + ListMultimap largeRanges = largeRangesBuilder.build(); + + // read ranges + ImmutableListMultimap.Builder slices = ImmutableListMultimap.builder(); + slices.putAll(readSmallDiskRanges(smallRanges)); + slices.putAll(readLargeDiskRanges(largeRanges)); + + return slices.build(); + } + + private void readFully(long position, byte[] buffer, int bufferOffset, int bufferLength) + { + readBytes += bufferLength; + + long start = System.nanoTime(); + try { + inputStream.readFully(position, buffer, bufferOffset, bufferLength); + } + catch (Exception e) { + throw new RuntimeException("Error reading from %s " + id + " at position " + position); + } + long currentReadTimeNanos = System.nanoTime() - start; + + readTimeNanos += currentReadTimeNanos; + } + + private ListMultimap readSmallDiskRanges(ListMultimap diskRanges) + { + if (diskRanges.isEmpty()) { + return ImmutableListMultimap.of(); + } + + Iterable mergedRanges = mergeAdjacentDiskRanges(diskRanges.values(), options.getMaxMergeDistance(), options.getMaxBufferSize()); + + ImmutableListMultimap.Builder slices = ImmutableListMultimap.builder(); + for (DiskRange mergedRange : mergedRanges) { + ReferenceCountedReader mergedRangeLoader = new ReferenceCountedReader(mergedRange); + + for (Map.Entry diskRangeEntry : diskRanges.entries()) { + DiskRange diskRange = diskRangeEntry.getValue(); + if (mergedRange.contains(diskRange)) { + mergedRangeLoader.addReference(); + + slices.put(diskRangeEntry.getKey(), new ChunkReader() + { + @Override + public Slice read() + { + int offset = toIntExact(diskRange.getOffset() - mergedRange.getOffset()); + return mergedRangeLoader.read().slice(offset, 
Long.valueOf(diskRange.getLength()).intValue()); + } + + @Override + public void free() + { + mergedRangeLoader.free(); + } + + @Override + public long getDiskOffset() + { + return diskRange.getOffset(); + } + }); + } + } + + mergedRangeLoader.free(); + } + + ListMultimap sliceStreams = slices.build(); + verify(sliceStreams.keySet().equals(diskRanges.keySet())); + return sliceStreams; + } + + private ListMultimap readLargeDiskRanges(ListMultimap diskRanges) + { + if (diskRanges.isEmpty()) { + return ImmutableListMultimap.of(); + } + + ImmutableListMultimap.Builder slices = ImmutableListMultimap.builder(); + for (Map.Entry entry : diskRanges.entries()) { + slices.put(entry.getKey(), new ReferenceCountedReader(entry.getValue())); + } + return slices.build(); + } + + private static List mergeAdjacentDiskRanges(Collection diskRanges, DataSize maxMergeDistance, DataSize maxReadSize) + { + // sort ranges by start offset + List ranges = new ArrayList<>(diskRanges); + ranges.sort(comparingLong(DiskRange::getOffset)); + + long maxReadSizeBytes = maxReadSize.toBytes(); + long maxMergeDistanceBytes = maxMergeDistance.toBytes(); + + // merge overlapping ranges + ImmutableList.Builder result = ImmutableList.builder(); + DiskRange last = ranges.get(0); + for (int i = 1; i < ranges.size(); i++) { + DiskRange current = ranges.get(i); + DiskRange merged = null; + boolean blockTooLong = false; + try { + merged = last.span(current); + } + catch (ArithmeticException e) { + blockTooLong = true; + } + if (!blockTooLong && merged.getLength() <= maxReadSizeBytes && last.getEnd() + maxMergeDistanceBytes >= current.getOffset()) { + last = merged; + } + else { + result.add(last); + last = current; + } + } + result.add(last); + + return result.build(); + } + + private class ReferenceCountedReader + implements ChunkReader + { + private final DiskRange range; + private Slice data; + private int referenceCount = 1; + + public ReferenceCountedReader(DiskRange range) + { + this.range = range; + } 
+ + public void addReference() + { + checkState(referenceCount > 0, "Chunk reader is already closed"); + referenceCount++; + } + + @Override + public Slice read() + { + checkState(referenceCount > 0, "Chunk reader is already closed"); + + if (data == null) { + byte[] buffer = new byte[toIntExact(range.getLength())]; // throws ArithmeticException instead of silently truncating an oversized length + readFully(range.getOffset(), buffer, 0, buffer.length); + data = Slices.wrappedBuffer(buffer); + } + + return data; + } + + @Override + public void free() + { + checkState(referenceCount > 0, "Reference count is already 0"); + + referenceCount--; + if (referenceCount == 0) { + data = null; + } + } + + @Override + public long getDiskOffset() + { + return range.getOffset(); + } + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java index d42725e5acb2..7f448bdbed2d 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java @@ -50,7 +50,7 @@ public void testReadFloatDouble() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("byte_stream_split_float_and_double.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); readAndCompare(reader, getExpectedValues()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java new file mode 100644 index 000000000000..c178d5be0261
--- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import com.google.common.collect.ImmutableSet; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.HiddenColumnException; +import io.trino.parquet.metadata.ColumnChunkMetadata; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.EncodingStats; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Types; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.Set; + +import static org.apache.parquet.column.Encoding.PLAIN; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestHiddenColumnChunkMetaData +{ + @Test + public void testIsHiddenColumn() + { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(ColumnPath.fromDotString("a.b.c"), + "hdfs:/foo/bar/a.parquet"); + assertThat(HiddenColumnChunkMetaData.isHiddenColumn(column)).isTrue(); + } + + @Test + public void testIsNotHiddenColumn() + { + Set encodingSet = 
Collections.singleton(Encoding.RLE); + EncodingStats encodingStats = new EncodingStats.Builder() + .withV2Pages() + .addDictEncoding(PLAIN) + .addDataEncodings(ImmutableSet.copyOf(encodingSet)).build(); + PrimitiveType type = Types.optional(BINARY).named(""); + Statistics> stats = Statistics.createStats(type); + ColumnChunkMetadata column = ColumnChunkMetadata.get(ColumnPath.fromDotString("a.b.c"), type, + CompressionCodecName.GZIP, encodingStats, encodingSet, stats, -1, -1, -1, -1, -1); + assertThat(HiddenColumnChunkMetaData.isHiddenColumn(column)).isFalse(); + } + + @Test(expectedExceptions = HiddenColumnException.class) + public void testHiddenColumnException() + { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(ColumnPath.fromDotString("a.b.c"), + "hdfs:/foo/bar/a.parquet"); + column.getStatistics(); + } + + @Test + public void testNoHiddenColumnException() + { + Set encodingSet = Collections.singleton(Encoding.RLE); + EncodingStats encodingStats = new EncodingStats.Builder() + .withV2Pages() + .addDictEncoding(PLAIN) + .addDataEncodings(ImmutableSet.copyOf(encodingSet)).build(); + PrimitiveType type = Types.optional(BINARY).named(""); + Statistics> stats = Statistics.createStats(type); + ColumnChunkMetadata column = ColumnChunkMetadata.get(ColumnPath.fromDotString("a.b.c"), type, + CompressionCodecName.GZIP, encodingStats, encodingSet, stats, -1, -1, -1, -1, -1); + column.getStatistics(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java index aabb734e5b0c..49e4fc2f9d80 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java @@ -112,7 +112,7 @@ public void testNanosOutsideDayRange() ParquetDataSource dataSource = new FileParquetDataSource( new 
File(Resources.getResource("int96_timestamps_nanos_outside_day_range.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); @@ -166,11 +166,12 @@ private void testVariousTimestamps(TimestampType type) slice.length(), OptionalLong.empty(), null, - false); + false, + 0); // Read and assert ColumnReaderFactory columnReaderFactory = new ColumnReaderFactory(DateTimeZone.UTC, new ParquetReaderOptions()); ColumnReader reader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false, null, null, -1, -1); reader.setPageReader(pageReader, Optional.empty()); reader.prepareNextRead(valueCount); Block block = reader.readPrimitive().getBlock(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java index 102e2b4fc01b..a94ff78cf8f2 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java @@ -25,6 +25,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetTypeUtils; +import io.trino.parquet.crypto.InternalFileDecryptor; import io.trino.parquet.metadata.ColumnChunkMetadata; import org.apache.parquet.column.ColumnDescriptor; import 
org.apache.parquet.column.EncodingStats; @@ -183,7 +184,7 @@ public void dictionaryPage(CompressionCodec compressionCodec, DataPageType dataP out.write(compressedDataPage); byte[] bytes = out.toByteArray(); - PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); DictionaryPage uncompressedDictionaryPage = pageReader.readDictionaryPage(); assertThat(uncompressedDictionaryPage.getDictionarySize()).isEqualTo(dictionaryPageHeader.getDictionary_page_header().getNum_values()); assertEncodingEquals(uncompressedDictionaryPage.getEncoding(), dictionaryPageHeader.getDictionary_page_header().getEncoding()); @@ -193,7 +194,7 @@ public void dictionaryPage(CompressionCodec compressionCodec, DataPageType dataP assertPages(compressionCodec, totalValueCount, 3, pageHeader, compressedDataPage, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); // only dictionary - pageReader = createPageReader(0, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(Arrays.copyOf(bytes, dictionaryPageSize)))); + pageReader = createPageReader(0, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(Arrays.copyOf(bytes, dictionaryPageSize))), null, -1); assertThatThrownBy(pageReader::readDictionaryPage) .isInstanceOf(IllegalStateException.class) .hasMessageStartingWith("No more data left to read"); @@ -236,7 +237,7 @@ public void dictionaryPageNotFirst() int totalValueCount = valueCount * 2; // There is a dictionary, but it's there as the second page - PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); 
assertThat(pageReader.readDictionaryPage()).isNull(); assertThat(pageReader.readPage()).isNotNull(); assertThatThrownBy(pageReader::readPage) @@ -270,7 +271,7 @@ public void unusedDictionaryPage() byte[] bytes = out.toByteArray(); // There is a dictionary, but it's there as the second page - PageReader pageReader = createPageReader(valueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(valueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); assertThat(pageReader.readDictionaryPage()).isNotNull(); assertThat(pageReader.readPage()).isNotNull(); assertThat(pageReader.readPage()).isNull(); @@ -298,7 +299,7 @@ private static void assertPages( List slices) throws IOException { - PageReader pageReader = createPageReader(valueCount, compressionCodec, hasDictionary, slices); + PageReader pageReader = createPageReader(valueCount, compressionCodec, hasDictionary, slices, null, -1); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); assertThat(dictionaryPage != null).isEqualTo(hasDictionary); @@ -383,7 +384,7 @@ private static byte[] compress(CompressionCodec compressionCodec, byte[] bytes, throw new IllegalArgumentException("unsupported compression code " + compressionCodec); } - private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices) + private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices, InternalFileDecryptor fileDecryptor, int rowGroupOrdinal) throws IOException { EncodingStats.Builder encodingStats = new EncodingStats.Builder(); @@ -409,7 +410,8 @@ private static PageReader createPageReader(int valueCount, CompressionCodec comp columnChunkMetaData, new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, INT32, ""), 0, 0), null, - Optional.empty()); + Optional.empty(), + 
Optional.ofNullable(fileDecryptor)); } private static void assertDataPageEquals(PageHeader pageHeader, byte[] dataPage, byte[] compressedDataPage, DataPage decompressedPage) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java index 2ef475a7644f..0c4f3011dbb1 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java @@ -79,7 +79,7 @@ public void testColumnReaderMemoryUsage() columnNames, generateInputPages(types, 100, 5)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThan(1); // Verify file has only non-dictionary encodings as dictionary memory usage is already tested in TestFlatColumnReader#testMemoryUsage parquetMetadata.getBlocks().forEach(block -> { @@ -132,7 +132,7 @@ public void testEmptyRowRangesWithColumnIndex() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("lineitem_sorted_by_shipdate/data.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(2); // The predicate and the file are prepared so that page indexes will result in non-overlapping row ranges and eliminate the entire first row group // while the second row group still has to be read @@ -193,7 +193,7 @@ private void testReadingOldParquetFiles(File file, List columnNames, Typ file, new ParquetReaderOptions()); ConnectorSession 
session = TestingConnectorSession.builder().build(); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), ImmutableList.of(columnType), columnNames)) { Page page = reader.nextPage(); Iterator> expected = expectedValues.iterator(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java index 390608f445a9..99ae226bca08 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java @@ -60,7 +60,7 @@ private void testTimeMillsInt32(TimeType timeType) ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("time_millis_int32.snappy.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java index a3efb46b6d71..8222899ab90b 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java @@ -137,8 +137,9 @@ private static PageReader getSimplePageReaderMock(ParquetEncoding encoding) OptionalLong.empty(), 
encoding, encoding, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, null, null, -1, -1); } private static PageReader getNullOnlyPageReaderMock() @@ -154,7 +155,8 @@ private static PageReader getNullOnlyPageReaderMock() OptionalLong.empty(), RLE, RLE, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, null, null, -1, -1); } } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java index 846080c3297a..717474419d11 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java @@ -127,7 +127,7 @@ public void testWrittenPageSize() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -141,6 +141,7 @@ public void testWrittenPageSize() chunkMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -176,7 +177,7 @@ public void testWrittenPageValueCount() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = 
MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -194,6 +195,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -213,6 +215,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columnb"}, new PrimitiveType(REQUIRED, INT64, "columnb"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -256,8 +259,7 @@ public void testLargeStringTruncation() columnNames, ImmutableList.of(new Page(2, blockA, blockB))), new ParquetReaderOptions()); - - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); BlockMetadata blockMetaData = getOnlyElement(parquetMetadata.getBlocks()); ColumnChunkMetadata chunkMetaData = blockMetaData.columns().get(0); @@ -290,7 +292,7 @@ public void testColumnReordering() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(10); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { // Verify that the columns are stored in the same order as the metadata @@ -347,7 +349,7 @@ public void testDictionaryPageOffset() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - 
ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(1); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { ColumnChunkMetadata chunkMetaData = getOnlyElement(blockMetaData.columns()); @@ -393,7 +395,7 @@ public void testWriteBloomFilters(Type type, List> data) generateInputPages(types, 100, data)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); // Check that bloom filters are right after each other int bloomFilterSize = Integer.highestOneBit(BlockSplitBloomFilter.optimalNumOfBits(BLOOM_FILTER_EXPECTED_ENTRIES, DEFAULT_BLOOM_FILTER_FPP) / 8) << 1; for (BlockMetadata block : parquetMetadata.getBlocks()) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index 5fe764a72756..eb0a41cd8108 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -362,7 +362,7 @@ private Slice writeMergeResult(Slice path, FileDeletion deletion) TrinoInputFile inputFile = fileSystem.newInputFile(Location.of(path.toStringUtf8())); try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, parquetReaderOptions, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); long rowCount = 
parquetMetadata.getBlocks().stream().map(BlockMetadata::rowCount).mapToLong(Long::longValue).sum(); RoaringBitmapArray rowsRetained = new RoaringBitmapArray(); rowsRetained.addRange(0, rowCount - 1); @@ -637,7 +637,8 @@ private ReaderPageSource createParquetPageSource(Location path) new ParquetReaderOptions().withBloomFilter(false), Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + null); } @Override diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index f08ecc84f839..c552b1944e2c 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -254,7 +254,8 @@ public ConnectorPageSource createPageSource( options, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(split.getFileSize())); + OptionalLong.of(split.getFileSize()), + null); Optional projectionsAdapter = pageSource.getReaderColumns().map(readerColumns -> new ReaderProjectionsAdapter( @@ -306,7 +307,7 @@ private PositionDeleteFilter readDeletes( public Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java index 8f686205e239..5330c6edd100 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java @@ -184,7 +184,7 @@ public DataFileInfo getDataFileInfo() { Location path = rootTableLocation.appendPath(relativeFilePath); FileMetaData fileMetaData = fileWriter.getFileMetadata(); - ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString())); + ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString()), Optional.empty(), false); return new DataFileInfo( relativeFilePath, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java index 7f5d4b8a88c6..3c57de2ef2f3 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java @@ -205,7 +205,8 @@ private static DeltaLakePageSource createDeltaLakePageSource( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.empty()); + OptionalLong.of(split.fileSize()), + null); verify(pageSource.getReaderColumns().isEmpty(), "Unexpected reader columns: %s", pageSource.getReaderColumns().orElse(null)); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 
04673aeab8ea..985cc433aaea 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -231,7 +231,8 @@ public CheckpointEntryIterator( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + Optional.empty()); this.pageSource = (ParquetPageSource) pageSource.get(); try { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java index 70cdce9c5e4f..7f9050feb512 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java @@ -329,7 +329,7 @@ private void testOptimizeWithColumnMappingMode(String columnMappingMode) TrinoInputFile inputFile = new LocalInputFile(tableLocation.resolve(addFileEntry.getPath()).toFile()); ParquetMetadata parquetMetadata = MetadataReader.readFooter( new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()), - Optional.empty()); + Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); PrimitiveType physicalType = getOnlyElement(fileMetaData.getSchema().getColumns().iterator()).getPrimitiveType(); assertThat(physicalType.getName()).isEqualTo(physicalName); diff --git a/plugin/trino-geospatial/pom.xml b/plugin/trino-geospatial/pom.xml index 6d975cb4232a..1636b6f25102 100644 --- a/plugin/trino-geospatial/pom.xml +++ b/plugin/trino-geospatial/pom.xml @@ -230,4 +230,20 @@ test
+ * The keytools package (PARQUET-1373) implements one approach, of many possible, to key management and to generation of the "key metadata" + * fields. This approach, based on the "envelope encryption" pattern, allows to work with KMS servers. It keeps the actual material, + * required to recover a key, in a "key material" object (see the KeyMaterial class for details). + *
+ * KeyMetadata class writes (and reads) the "key metadata" field as a flat json object, with the following fields: + * 1. "keyMaterialType" - a String, with the type of key material. In the current version, only one value is allowed - "PKMT1" (stands + * for "parquet key management tools, version 1") + * 2. "internalStorage" - a boolean. If true, means that "key material" is kept inside the "key metadata" field. If false, "key material" + * is kept externally (outside Parquet files) - in this case, "key metadata" keeps a reference to the external "key material". + * 3. "keyReference" - a String, with the reference to the external "key material". Written only if internalStorage is false. + *
+ * If internalStorage is true, "key material" is a part of "key metadata", and the json keeps additional fields, described in the + * KeyMaterial class. + */ +public class KeyMetadata +{ + static final String KEY_MATERIAL_INTERNAL_STORAGE_FIELD = "internalStorage"; + private static final String KEY_REFERENCE_FIELD = "keyReference"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private final boolean isInternalStorage; + private final String keyReference; + private final KeyMaterial keyMaterial; + + private KeyMetadata(boolean isInternalStorage, String keyReference, KeyMaterial keyMaterial) + { + this.isInternalStorage = isInternalStorage; + this.keyReference = keyReference; + this.keyMaterial = keyMaterial; + } + + static KeyMetadata parse(byte[] keyMetadataBytes) + { + String keyMetaDataString = new String(keyMetadataBytes, StandardCharsets.UTF_8); + Map keyMetadataJson = null; + try { + keyMetadataJson = OBJECT_MAPPER.readValue( + new StringReader(keyMetaDataString), new TypeReference>() {}); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to parse key metadata " + keyMetaDataString, e); + } + + // 1. Extract "key material type", and make sure it is supported + String keyMaterialType = (String) keyMetadataJson.get(KeyMaterial.KEY_MATERIAL_TYPE_FIELD); + if (!KeyMaterial.KEY_MATERIAL_TYPE1.equals(keyMaterialType)) { + throw new ParquetCryptoRuntimeException( + "Wrong key material type: " + keyMaterialType + " vs " + KeyMaterial.KEY_MATERIAL_TYPE1); + } + + // 2. 
Check if "key material" is stored internally in Parquet file key metadata, or is stored externally + Boolean isInternalStorage = (Boolean) keyMetadataJson.get(KEY_MATERIAL_INTERNAL_STORAGE_FIELD); + String keyReference; + KeyMaterial keyMaterial; + + if (isInternalStorage) { + // 3.1 "key material" is stored internally, inside "key metadata" - parse it + keyMaterial = KeyMaterial.parse(keyMetadataJson); + keyReference = null; + } + else { + // 3.2 "key material" is stored externally. "key metadata" keeps a reference to it + keyReference = (String) keyMetadataJson.get(KEY_REFERENCE_FIELD); + keyMaterial = null; + } + + return new KeyMetadata(isInternalStorage, keyReference, keyMaterial); + } + + // For external material only. For internal material, create serialized KeyMaterial directly + static String createSerializedForExternalMaterial(String keyReference) + { + Map keyMetadataMap = new HashMap(3); + // 1. Write "key material type" + keyMetadataMap.put(KeyMaterial.KEY_MATERIAL_TYPE_FIELD, KeyMaterial.KEY_MATERIAL_TYPE1); + // 2. Write internal storage as false + keyMetadataMap.put(KEY_MATERIAL_INTERNAL_STORAGE_FIELD, Boolean.FALSE); + // 3. 
For externally stored "key material", "key metadata" keeps only a reference to it + keyMetadataMap.put(KEY_REFERENCE_FIELD, keyReference); + + try { + return OBJECT_MAPPER.writeValueAsString(keyMetadataMap); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to serialize key metadata", e); + } + } + + boolean keyMaterialStoredInternally() + { + return isInternalStorage; + } + + KeyMaterial getKeyMaterial() + { + return keyMaterial; + } + + String getKeyReference() + { + return keyReference; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java new file mode 100644 index 000000000000..0c5bee3b6da8 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java @@ -0,0 +1,164 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.parquet.crypto.keytools; + +import com.google.common.base.Strings; +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KeyWithMasterID; + +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +import static io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KEK_READ_CACHE_PER_TOKEN; +import static io.trino.parquet.crypto.keytools.TrinoKeyToolkit.KMS_CLIENT_CACHE_PER_TOKEN; + +public class TrinoFileKeyUnwrapper + implements DecryptionKeyRetriever +{ + private static final Logger LOG = Logger.get(TrinoFileKeyUnwrapper.class); + + //A map of KEK_ID -> KEK bytes, for the current token + private final ConcurrentMap kekPerKekID; + private final Location parquetFilePath; + // TODO(wyu): shall we get it from Location or File + private final TrinoFileSystem trinoFileSystem; + private final String accessToken; + private final long cacheEntryLifetime; + private final ParquetReaderOptions parquetReaderOptions; + private TrinoKeyToolkit.TrinoKmsClientAndDetails kmsClientAndDetails; + private TrinoHadoopFSKeyMaterialStore keyMaterialStore; + private boolean checkedKeyMaterialInternalStorage; + + TrinoFileKeyUnwrapper(ParquetReaderOptions conf, Location filePath, TrinoFileSystem trinoFileSystem) + { + this.trinoFileSystem = trinoFileSystem; + this.parquetReaderOptions = conf; + this.parquetFilePath = filePath; + this.cacheEntryLifetime = 1000L * conf.getEncryptionCacheLifetimeSeconds(); + this.accessToken = conf.getEncryptionKeyAccessToken(); + this.kmsClientAndDetails = null; + this.keyMaterialStore = null; + this.checkedKeyMaterialInternalStorage = false; + + // Check cache upon each file reading (clean once in cacheEntryLifetime) + 
KMS_CLIENT_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + KEK_READ_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + kekPerKekID = KEK_READ_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + if (LOG.isDebugEnabled()) { + LOG.debug("Creating file key unwrapper. KeyMaterialStore: {}; token snippet: {}", + keyMaterialStore, TrinoKeyToolkit.formatTokenForLog(accessToken)); + } + } + + @Override + public byte[] getKey(byte[] keyMetadataBytes) + { + KeyMetadata keyMetadata = KeyMetadata.parse(keyMetadataBytes); + + if (!checkedKeyMaterialInternalStorage) { + if (!keyMetadata.keyMaterialStoredInternally()) { + keyMaterialStore = new TrinoHadoopFSKeyMaterialStore(trinoFileSystem, parquetFilePath, false); + } + checkedKeyMaterialInternalStorage = true; + } + + KeyMaterial keyMaterial; + if (keyMetadata.keyMaterialStoredInternally()) { + // Internal key material storage: key material is inside key metadata + keyMaterial = keyMetadata.getKeyMaterial(); + } + else { + // External key material storage: key metadata contains a reference to a key in the material store + String keyIDinFile = keyMetadata.getKeyReference(); + String keyMaterialString = keyMaterialStore.getKeyMaterial(keyIDinFile); + if (null == keyMaterialString) { + throw new ParquetCryptoRuntimeException("Null key material for keyIDinFile: " + keyIDinFile); + } + keyMaterial = KeyMaterial.parse(keyMaterialString); + } + + return getDEKandMasterID(keyMaterial).getDataKey(); + } + + KeyWithMasterID getDEKandMasterID(KeyMaterial keyMaterial) + { + if (null == kmsClientAndDetails) { + kmsClientAndDetails = getKmsClientFromConfigOrKeyMaterial(keyMaterial); + } + + boolean doubleWrapping = keyMaterial.isDoubleWrapped(); + String masterKeyID = keyMaterial.getMasterKeyID(); + String encodedWrappedDEK = keyMaterial.getWrappedDEK(); + + byte[] dataKey; + TrinoKmsClient kmsClient = kmsClientAndDetails.getKmsClient(); + if (!doubleWrapping) { + dataKey = 
kmsClient.unwrapKey(encodedWrappedDEK, masterKeyID); + } + else { + // Get KEK + String encodedKekID = keyMaterial.getKekID(); + String encodedWrappedKEK = keyMaterial.getWrappedKEK(); + + byte[] kekBytes = kekPerKekID.computeIfAbsent(encodedKekID, + (k) -> kmsClient.unwrapKey(encodedWrappedKEK, masterKeyID)); + + if (null == kekBytes) { + throw new ParquetCryptoRuntimeException("Null KEK, after unwrapping in KMS with master key " + masterKeyID); + } + + // Decrypt the data key + byte[] aad = Base64.getDecoder().decode(encodedKekID); + dataKey = TrinoKeyToolkit.decryptKeyLocally(encodedWrappedDEK, kekBytes, aad); + } + + return new KeyWithMasterID(dataKey, masterKeyID); + } + + TrinoKeyToolkit.TrinoKmsClientAndDetails getKmsClientFromConfigOrKeyMaterial(KeyMaterial keyMaterial) + { + String kmsInstanceID = this.parquetReaderOptions.getEncryptionKmsInstanceId(); + if (Strings.isNullOrEmpty(kmsInstanceID)) { + kmsInstanceID = keyMaterial.getKmsInstanceID(); + if (null == kmsInstanceID) { + throw new ParquetCryptoRuntimeException("KMS instance ID is missing both in properties and file key material"); + } + } + + String kmsInstanceURL = this.parquetReaderOptions.getEncryptionKmsInstanceUrl(); + if (Strings.isNullOrEmpty(kmsInstanceURL)) { + kmsInstanceURL = keyMaterial.getKmsInstanceURL(); + if (null == kmsInstanceURL) { + throw new ParquetCryptoRuntimeException("KMS instance URL is missing both in properties and file key material"); + } + } + + TrinoKmsClient kmsClient = TrinoKeyToolkit.getKmsClient(kmsInstanceID, kmsInstanceURL, this.parquetReaderOptions, accessToken, cacheEntryLifetime); + if (null == kmsClient) { + throw new ParquetCryptoRuntimeException("KMSClient was not successfully created for reading encrypted data."); + } + + if (LOG.isDebugEnabled()) { + LOG.debug("File unwrapper - KmsClient: {}; InstanceId: {}; InstanceURL: {}", kmsClient, kmsInstanceID, kmsInstanceURL); + } + return new TrinoKeyToolkit.TrinoKmsClientAndDetails(kmsClient, kmsInstanceID, 
kmsInstanceURL); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java new file mode 100644 index 000000000000..4c178c0bd8fe --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoInputStream; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Map; + +public class TrinoHadoopFSKeyMaterialStore +{ + public static final String KEY_MATERIAL_FILE_PREFIX = "_KEY_MATERIAL_FOR_"; + public static final String TEMP_FILE_PREFIX = "_TMP"; + public static final String KEY_MATERIAL_FILE_SUFFFIX = ".json"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private TrinoFileSystem trinoFileSystem; + private Map keyMaterialMap; + private Location keyMaterialFile; + + 
TrinoHadoopFSKeyMaterialStore(TrinoFileSystem trinoFileSystem, Location parquetFilePath, boolean tempStore) + { + this.trinoFileSystem = trinoFileSystem; + String fullPrefix = (tempStore ? TEMP_FILE_PREFIX : ""); + fullPrefix += KEY_MATERIAL_FILE_PREFIX; + keyMaterialFile = parquetFilePath.parentDirectory().appendPath( + fullPrefix + parquetFilePath.fileName() + KEY_MATERIAL_FILE_SUFFFIX); + } + + public String getKeyMaterial(String keyIDInFile) + throws ParquetCryptoRuntimeException + { + if (null == keyMaterialMap) { + loadKeyMaterialMap(); + } + return keyMaterialMap.get(keyIDInFile); + } + + private void loadKeyMaterialMap() + { + TrinoInputFile inputfile = trinoFileSystem.newInputFile(keyMaterialFile); + try (TrinoInputStream keyMaterialStream = inputfile.newStream()) { + JsonNode keyMaterialJson = objectMapper.readTree(keyMaterialStream); + keyMaterialMap = objectMapper.readValue(keyMaterialJson.traverse(), + new TypeReference>() {}); + } + catch (FileNotFoundException e) { + throw new ParquetCryptoRuntimeException("External key material not found at " + keyMaterialFile, e); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to get key material from " + keyMaterialFile, e); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java new file mode 100644 index 000000000000..eb05702732ba --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKeyToolkit.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.AesGcmDecryptor; +import io.trino.parquet.crypto.AesMode; +import io.trino.parquet.crypto.ModuleCipherFactory; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TrinoCryptoConfigurationUtil; + +import java.lang.reflect.InvocationTargetException; +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +public class TrinoKeyToolkit +{ + public static final long CACHE_LIFETIME_DEFAULT_SECONDS = 10 * 60; // 10 minutes + + // KMS client two level cache: token -> KMSInstanceId -> KmsClient + static final TwoLevelCacheWithExpiration KMS_CLIENT_CACHE_PER_TOKEN = + KmsClientCache.INSTANCE.getCache(); + + // KEK two level cache for unwrapping: token -> KEK_ID -> KEK bytes + static final TwoLevelCacheWithExpiration KEK_READ_CACHE_PER_TOKEN = + KEKReadCache.INSTANCE.getCache(); + + private TrinoKeyToolkit() + { + } + + private enum KmsClientCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + private enum KEKReadCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + static String formatTokenForLog(String accessToken) + { + int maxTokenDisplayLength = 5; + if (accessToken.length() <= 
maxTokenDisplayLength) { + return accessToken; + } + return accessToken.substring(accessToken.length() - maxTokenDisplayLength); + } + + static class KeyWithMasterID + { + private final byte[] keyBytes; + private final String masterID; + + KeyWithMasterID(byte[] keyBytes, String masterID) + { + this.keyBytes = keyBytes; + this.masterID = masterID; + } + + byte[] getDataKey() + { + return keyBytes; + } + + String getMasterID() + { + return masterID; + } + } + + static class KeyEncryptionKey + { + private final byte[] kekBytes; + private final byte[] kekID; + private String encodedKekID; + private final String encodedWrappedKEK; + + KeyEncryptionKey(byte[] kekBytes, byte[] kekID, String encodedWrappedKEK) + { + this.kekBytes = kekBytes; + this.kekID = kekID; + this.encodedWrappedKEK = encodedWrappedKEK; + } + + byte[] getBytes() + { + return kekBytes; + } + + byte[] getID() + { + return kekID; + } + + String getEncodedID() + { + if (null == encodedKekID) { + encodedKekID = Base64.getEncoder().encodeToString(kekID); + } + return encodedKekID; + } + + String getEncodedWrappedKEK() + { + return encodedWrappedKEK; + } + } + + /** + * Decrypts encrypted key with "masterKey", using AES-GCM and the "aad" + * + * @param encodedEncryptedKey base64 encoded encrypted key + * @param masterKeyBytes encryption key + * @param aad additional authenticated data + * @return decrypted key + */ + public static byte[] decryptKeyLocally(String encodedEncryptedKey, byte[] masterKeyBytes, byte[] aad) + { + byte[] encryptedKey = Base64.getDecoder().decode(encodedEncryptedKey); + + AesGcmDecryptor keyDecryptor; + + keyDecryptor = (AesGcmDecryptor) ModuleCipherFactory.getDecryptor(AesMode.GCM, masterKeyBytes); + + return keyDecryptor.decrypt(encryptedKey, 0, encryptedKey.length, aad); + } + + static TrinoKmsClient getKmsClient(String kmsInstanceID, String kmsInstanceURL, ParquetReaderOptions trinoParquetCryptoConfig, + String accessToken, long cacheEntryLifetime) + { + ConcurrentMap 
kmsClientPerKmsInstanceCache = + KMS_CLIENT_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + TrinoKmsClient kmsClient = + kmsClientPerKmsInstanceCache.computeIfAbsent(kmsInstanceID, + (k) -> createAndInitKmsClient(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken)); + + return kmsClient; + } + + private static TrinoKmsClient createAndInitKmsClient(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, + String kmsInstanceURL, String accessToken) + { + Class> kmsClientClass = null; + TrinoKmsClient kmsClient; + + try { + kmsClientClass = TrinoCryptoConfigurationUtil.getClassFromConfig(trinoParquetCryptoConfig.getEncryptionKmsClientClass(), + TrinoKmsClient.class); + + if (null == kmsClientClass) { + throw new ParquetCryptoRuntimeException("Could not find class " + trinoParquetCryptoConfig.getEncryptionKmsClientClass()); + } + kmsClient = (TrinoKmsClient) kmsClientClass.getConstructor().newInstance(); + } + catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + throw new ParquetCryptoRuntimeException("Could not instantiate KmsClient class: " + + kmsClientClass, e); + } + + kmsClient.initialize(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken); + + return kmsClient; + } + + static class TrinoKmsClientAndDetails + { + public TrinoKmsClient getKmsClient() + { + return kmsClient; + } + + private TrinoKmsClient kmsClient; + private String kmsInstanceID; + private String kmsInstanceURL; + + public TrinoKmsClientAndDetails(TrinoKmsClient kmsClient, String kmsInstanceID, String kmsInstanceURL) + { + this.kmsClient = kmsClient; + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + } + + public String getKmsInstanceID() + { + return kmsInstanceID; + } + + public String getKmsInstanceURL() + { + return kmsInstanceURL; + } + } +} diff --git 
a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java new file mode 100644 index 000000000000..6ca6cb0cb53e --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoKmsClient.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.KeyAccessDeniedException; + +public interface TrinoKmsClient +{ + String KEY_ACCESS_TOKEN_DEFAULT = "DEFAULT"; + + void initialize(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, String kmsInstanceURL, String accessToken) + throws KeyAccessDeniedException; + + String wrapKey(byte[] keyBytes, String masterKeyIdentifier) + throws KeyAccessDeniedException; + + byte[] unwrapKey(String wrappedKey, String masterKeyIdentifier) + throws KeyAccessDeniedException; +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java new file mode 100644 index 000000000000..8eb61c18c0e8 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto.keytools; + +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.FileDecryptionProperties; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TrinoDecryptionPropertiesFactory; + +public class TrinoPropertiesDrivenCryptoFactory + implements TrinoDecryptionPropertiesFactory +{ + private static final Logger LOG = Logger.get(TrinoPropertiesDrivenCryptoFactory.class); + + @Override + public FileDecryptionProperties getFileDecryptionProperties(ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + throws ParquetCryptoRuntimeException + { + DecryptionKeyRetriever keyRetriever = new TrinoFileKeyUnwrapper(parquetReaderOptions, filePath, trinoFileSystem); + + if (LOG.isDebugEnabled()) { + LOG.debug("File decryption properties for {}", filePath); + } + + return FileDecryptionProperties.builder() + .withKeyRetriever(keyRetriever) + .withPlaintextFilesAllowed() + .build(); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/keytools/TwoLevelCacheWithExpiration.java new file mode 100644 index 000000000000..ca2e7d2d356d --- /dev/null +++ 
/**
 * Concurrent two-level cache with expiration of internal caches according to token lifetime.
 * External cache is per token, internal is per String key.
 *
 * @param <V> Value
 */
class TwoLevelCacheWithExpiration<V>
{
    private final ConcurrentMap<String, ExpiringCacheEntry<ConcurrentMap<String, V>>> cache;
    // Time of the most recent cleanup pass (initially the construction time)
    private volatile long lastCacheCleanupTimestamp;

    TwoLevelCacheWithExpiration()
    {
        this.cache = new ConcurrentHashMap<>();
        this.lastCacheCleanupTimestamp = System.currentTimeMillis();
    }

    /**
     * Returns the internal cache of the given token, replacing it with a fresh empty one when it is missing or expired.
     *
     * @param accessToken token that owns the internal cache
     * @param cacheEntryLifetime lifetime, in milliseconds, given to a newly created internal cache
     */
    ConcurrentMap<String, V> getOrCreateInternalCache(String accessToken, long cacheEntryLifetime)
    {
        ExpiringCacheEntry<ConcurrentMap<String, V>> externalCacheEntry =
                cache.compute(accessToken, (token, cacheEntry) -> {
                    if ((null == cacheEntry) || cacheEntry.isExpired()) {
                        return new ExpiringCacheEntry<>(new ConcurrentHashMap<String, V>(), cacheEntryLifetime);
                    }
                    return cacheEntry;
                });
        return externalCacheEntry.getCachedItem();
    }

    void removeCacheEntriesForToken(String accessToken)
    {
        cache.remove(accessToken);
    }

    void removeCacheEntriesForAllTokens()
    {
        cache.clear();
    }

    /**
     * Drops expired internal caches, at most once per {@code cacheCleanupPeriod}.
     *
     * @param cacheCleanupPeriod minimum time, in milliseconds, between two cleanup passes
     */
    public void checkCacheForExpiredTokens(long cacheCleanupPeriod)
    {
        long now = System.currentTimeMillis();

        if (now > (lastCacheCleanupTimestamp + cacheCleanupPeriod)) {
            synchronized (cache) {
                // Re-check under the lock so that concurrent callers do not repeat the pass
                if (now > (lastCacheCleanupTimestamp + cacheCleanupPeriod)) {
                    removeExpiredEntriesFromCache();
                    // Record when the cleanup ran. Storing now + period here made the next pass
                    // wait two periods, because the check above adds the period again.
                    lastCacheCleanupTimestamp = now;
                }
            }
        }
    }

    public void removeExpiredEntriesFromCache()
    {
        cache.values().removeIf(ExpiringCacheEntry::isExpired);
    }

    public void remove(String accessToken)
    {
        cache.remove(accessToken);
    }

    public void clear()
    {
        cache.clear();
    }

    /** A cached item stamped with the wall-clock time after which it counts as expired. */
    static class ExpiringCacheEntry<E>
    {
        private final long expirationTimestamp;
        private final E cachedItem;

        private ExpiringCacheEntry(E cachedItem, long expirationIntervalMillis)
        {
            this.expirationTimestamp = System.currentTimeMillis() + expirationIntervalMillis;
            this.cachedItem = cachedItem;
        }

        private boolean isExpired()
        {
            return System.currentTimeMillis() > expirationTimestamp;
        }

        private E getCachedItem()
        {
            return cachedItem;
        }
    }
}
@@ -23,6 +23,9 @@ import java.util.Set; +import static io.trino.parquet.ParquetEncoding.PLAIN_DICTIONARY; +import static io.trino.parquet.ParquetEncoding.RLE_DICTIONARY; + public abstract class ColumnChunkMetadata { protected int rowGroupOrdinal = -1; @@ -200,4 +203,16 @@ public String toString() decryptIfNeeded(); return "ColumnMetaData{" + properties.toString() + ", " + getFirstDataPageOffset() + "}"; } + + public boolean hasDictionaryPage() + { + EncodingStats stats = getEncodingStats(); + if (stats != null) { + // ensure there is a dictionary page and that it is used to encode data pages + return stats.hasDictionaryPages() && stats.hasDictionaryEncodedPages(); + } + + Set encodings = getEncodings(); + return (encodings.contains(PLAIN_DICTIONARY) || encodings.contains(RLE_DICTIONARY)); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java index 6901bb23a4e6..3293e980e719 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java @@ -25,6 +25,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; @@ -230,9 +231,11 @@ private static Map> getStatistics(PrunedBlockMet ImmutableMap.Builder> statistics = ImmutableMap.builderWithExpectedSize(descriptorsByPath.size()); for (ColumnDescriptor descriptor : descriptorsByPath.values()) { ColumnChunkMetadata columnMetaData = columnsMetadata.getColumnChunkMetaData(descriptor); - Statistics> columnStatistics = columnMetaData.getStatistics(); - if (columnStatistics != null) { - 
statistics.put(descriptor, columnStatistics); + if (!HiddenColumnChunkMetaData.isHiddenColumn(columnMetaData)) { + Statistics> columnStatistics = columnMetaData.getStatistics(); + if (columnStatistics != null) { + statistics.put(descriptor, columnStatistics); + } } } return statistics.buildOrThrow(); @@ -260,18 +263,20 @@ private static boolean dictionaryPredicatesMatch( { for (ColumnDescriptor descriptor : descriptorsByPath.values()) { ColumnChunkMetadata columnMetaData = columnsMetadata.getColumnChunkMetaData(descriptor); - if (!candidateColumns.contains(descriptor)) { - continue; - } - if (isOnlyDictionaryEncodingPages(columnMetaData)) { - Statistics> columnStatistics = columnMetaData.getStatistics(); - boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0; - // Early abort, predicate already filters block so no more dictionaries need be read - if (!parquetPredicate.matches(new DictionaryDescriptor( - descriptor, - nullAllowed, - readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { - return false; + if (!HiddenColumnChunkMetaData.isHiddenColumn(columnMetaData)) { + if (!candidateColumns.contains(descriptor)) { + continue; + } + if (isOnlyDictionaryEncodingPages(columnMetaData)) { + Statistics> columnStatistics = columnMetaData.getStatistics(); + boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0; + // Early abort, predicate already filters block so no more dictionaries need be read + if (!parquetPredicate.matches(new DictionaryDescriptor( + descriptor, + nullAllowed, + readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { + return false; + } } } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index fe0635646f98..294cfe0604b2 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ 
b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -15,26 +15,41 @@ import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; +import io.airlift.slice.BasicSliceInput; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.AesGcmEncryptor; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.KeyAccessDeniedException; +import io.trino.parquet.crypto.ModuleCipherFactory.ModuleType; +import io.trino.parquet.crypto.ParquetCryptoRuntimeException; +import io.trino.parquet.crypto.TagVerificationException; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.FileMetadata; import io.trino.parquet.metadata.ParquetMetadata; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; +import org.apache.parquet.format.BlockCipher.Decryptor; import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnCryptoMetaData; import org.apache.parquet.format.ColumnMetaData; import org.apache.parquet.format.Encoding; +import org.apache.parquet.format.EncryptionWithColumnKey; +import org.apache.parquet.format.FileCryptoMetaData; import org.apache.parquet.format.FileMetaData; import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.SchemaElement; import org.apache.parquet.format.Statistics; +import org.apache.parquet.format.Util; import org.apache.parquet.hadoop.metadata.ColumnPath; import 
org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -43,6 +58,7 @@ import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; @@ -56,7 +72,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.parquet.ParquetMetadataConverter.convertEncodingStats; import static io.trino.parquet.ParquetMetadataConverter.fromParquetStatistics; import static io.trino.parquet.ParquetMetadataConverter.getEncoding; @@ -69,6 +87,7 @@ import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static org.apache.parquet.format.Util.readFileCryptoMetaData; import static org.apache.parquet.format.Util.readFileMetaData; public final class MetadataReader @@ -76,13 +95,14 @@ public final class MetadataReader private static final Logger log = Logger.get(MetadataReader.class); private static final Slice MAGIC = Slices.utf8Slice("PAR1"); + private static final Slice EMAGIC = Slices.utf8Slice("PARE"); private static final int POST_SCRIPT_SIZE = Integer.BYTES + MAGIC.length(); // Typical 1GB files produced by Trino were found to have footer size between 30-40KB private static final int EXPECTED_FOOTER_SIZE = 48 * 1024; private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation, Optional fileDecryptor) throws IOException { // Parquet File Layout: @@ -93,7 +113,9 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< // 4 bytes: MetadataLength // MAGIC - 
validateParquet(dataSource.getEstimatedSize() >= MAGIC.length() + POST_SCRIPT_SIZE, dataSource.getId(), "%s is not a valid Parquet File", dataSource.getId()); + validateParquet((dataSource.getEstimatedSize() >= MAGIC.length() + POST_SCRIPT_SIZE) || + (dataSource.getEstimatedSize() >= EMAGIC.length() + POST_SCRIPT_SIZE), dataSource.getId(), + "%s is not a valid Parquet File", dataSource.getId()); // Read the tail of the file long estimatedFileSize = dataSource.getEstimatedSize(); @@ -101,8 +123,10 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< Slice buffer = dataSource.readTail(toIntExact(expectedReadSize)); Slice magic = buffer.slice(buffer.length() - MAGIC.length(), MAGIC.length()); - validateParquet(MAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s got: %s", MAGIC.toStringUtf8(), magic.toStringUtf8()); + validateParquet(MAGIC.equals(magic) || EMAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s or %s got: %s", MAGIC.toStringUtf8(), EMAGIC.toStringUtf8(), magic.toStringUtf8()); + boolean encryptedFooterMode = EMAGIC.equals(magic); + checkArgument(!encryptedFooterMode || !(fileDecryptor.isEmpty() || fileDecryptor.get().getDecryptionProperties() == null), "fileDecryptionProperties cannot be null when encryptedFooterMode is true"); int metadataLength = buffer.getInt(buffer.length() - POST_SCRIPT_SIZE); long metadataIndex = estimatedFileSize - POST_SCRIPT_SIZE - metadataLength; validateParquet( @@ -118,13 +142,44 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< } InputStream metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); - FileMetaData fileMetaData = readFileMetaData(metadataStream); - ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, dataSource.getId()); + Decryptor footerDecryptor = null; + byte[] aad = null; + + if (encryptedFooterMode) { + FileCryptoMetaData fileCryptoMetaData = 
readFileCryptoMetaData(metadataStream); + fileDecryptor.get().setFileCryptoMetaData(fileCryptoMetaData.getEncryption_algorithm(), true, fileCryptoMetaData.getKey_metadata()); + footerDecryptor = fileDecryptor.get().fetchFooterDecryptor(); + aad = AesCipher.createFooterAAD(fileDecryptor.get().getFileAAD()); + } + FileMetaData fileMetaData = readFileMetaData(metadataStream, footerDecryptor, aad); + if (!encryptedFooterMode && fileDecryptor.isPresent()) { + if (!fileMetaData.isSetEncryption_algorithm()) { // Plaintext file + fileDecryptor.get().setPlaintextFile(); + // Done to detect files that were not encrypted by mistake + if (!fileDecryptor.get().plaintextFilesAllowed()) { + throw new ParquetCryptoRuntimeException("Applying decryptor on plaintext file"); + } + } + else { // Encrypted file with plaintext footer + // if no fileDecryptor, can still read plaintext columns + fileDecryptor.get().setFileCryptoMetaData(fileMetaData.getEncryption_algorithm(), false, + fileMetaData.getFooter_signing_key_metadata()); + if (fileDecryptor.get().checkFooterIntegrity()) { + verifyFooterIntegrity(metadataStream, fileDecryptor.get(), metadataLength); + } + } + } + ParquetDataSourceId id = dataSource.getId(); + ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, id, fileDecryptor, encryptedFooterMode); + validateFileMetadata(id, parquetMetadata.getFileMetaData(), parquetWriteValidation); validateFileMetadata(dataSource.getId(), parquetMetadata.getFileMetaData(), parquetWriteValidation); return parquetMetadata; } - public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, ParquetDataSourceId dataSourceId) + public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, + ParquetDataSourceId dataSourceId, + Optional fileDecryptor, + boolean encryptedFooterMode) throws ParquetCorruptionException { List schema = fileMetaData.getSchema(); @@ -138,37 +193,79 @@ public static ParquetMetadata createParquetMetadata(FileMetaData 
fileMetaData, P List columns = rowGroup.getColumns(); validateParquet(!columns.isEmpty(), dataSourceId, "No columns in row group: %s", rowGroup); String filePath = columns.get(0).getFile_path(); + int columnOrdinal = -1; ImmutableList.Builder columnMetadataBuilder = ImmutableList.builderWithExpectedSize(columns.size()); for (ColumnChunk columnChunk : columns) { + columnOrdinal++; validateParquet( (filePath == null && columnChunk.getFile_path() == null) || (filePath != null && filePath.equals(columnChunk.getFile_path())), dataSourceId, "all column chunks of the same row group must be in the same file"); + ColumnCryptoMetaData cryptoMetaData = columnChunk.getCrypto_metadata(); ColumnMetaData metaData = columnChunk.meta_data; - String[] path = metaData.path_in_schema.stream() - .map(value -> value.toLowerCase(Locale.ENGLISH)) - .toArray(String[]::new); - ColumnPath columnPath = ColumnPath.get(path); - PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); - ColumnChunkMetadata column = ColumnChunkMetadata.get( - columnPath, - primitiveType, - CompressionCodecName.fromParquet(metaData.codec), - convertEncodingStats(metaData.encoding_stats), - readEncodings(metaData.encodings), - readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), - metaData.data_page_offset, - metaData.dictionary_page_offset, - metaData.num_values, - metaData.total_compressed_size, - metaData.total_uncompressed_size); - column.setColumnIndexReference(toColumnIndexReference(columnChunk)); - column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); - column.setBloomFilterOffset(metaData.bloom_filter_offset); - columnMetadataBuilder.add(column); + ColumnPath columnPath = null; + boolean encryptedMetadata = false; + if (cryptoMetaData == null) { + columnPath = getPath(metaData); + if (fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile()) { + // mark this column as plaintext in 
encrypted file decryptor + fileDecryptor.get().setColumnCryptoMetadata(columnPath, false, false, (byte[]) null, columnOrdinal); + } + } + else { // Encrypted column + if (cryptoMetaData.isSetENCRYPTION_WITH_FOOTER_KEY()) { // Column encrypted with footer key + if (!encryptedFooterMode) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key in file with plaintext footer"); + } + if (null == metaData) { + throw new ParquetCryptoRuntimeException("ColumnMetaData not set in Encryption with Footer key"); + } + if (fileDecryptor.isEmpty()) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key: No keys available"); + } + columnPath = getPath(metaData); + fileDecryptor.get().setColumnCryptoMetadata(columnPath, true, true, (byte[]) null, columnOrdinal); + } + else { // Column encrypted with column key + encryptedMetadata = true; + } + } + try { + if (encryptedMetadata) { + // TODO: We decrypted data before filter projection. This could send unnecessary traffic to KMS. + // In parquet-mr, it uses lazy decyrption but that required to change ColumnChunkMetadata. We will improve it alter. 
+ metaData = decryptMetadata(rowGroup, cryptoMetaData, columnChunk, fileDecryptor.get(), columnOrdinal); + columnPath = getPath(metaData); + } + PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); + ColumnChunkMetadata column = ColumnChunkMetadata.get( + columnPath, + primitiveType, + CompressionCodecName.fromParquet(metaData.codec), + convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + column.setColumnIndexReference(toColumnIndexReference(columnChunk)); + column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); + column.setBloomFilterOffset(metaData.bloom_filter_offset); + + if (rowGroup.isSetOrdinal()) { + column.setRowGroupOrdinal(rowGroup.getOrdinal()); + } + columnMetadataBuilder.add(column); + } + catch (KeyAccessDeniedException e) { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(columnPath, filePath); + columnMetadataBuilder.add(column); + } } - blocks.add(new BlockMetadata(rowGroup.getNum_rows(), columnMetadataBuilder.build())); + blocks.add(new BlockMetadata(rowGroup.getNum_rows(), rowGroup.getTotal_byte_size(), rowGroup.getOrdinal(), columnMetadataBuilder.build())); } } @@ -274,6 +371,25 @@ public static org.apache.parquet.column.statistics.Statistics> readStats(Optio return columnStatistics; } + /** + * If a column is encrypted and user doesn't provide correct key to decrypt, that column is hidden to current request. + * This method find out the first non-hidden column. + * + * @param block BlockMetaData + * @return first non hidden column id. 
+ */ + public static Integer findFirstNonHiddenColumnId(BlockMetadata block) + { + List columns = block.columns(); + for (int i = 0; i < columns.size(); i++) { + if (!HiddenColumnChunkMetaData.isHiddenColumn(columns.get(i))) { + return i; + } + } + // all columns are hidden (encrypted but not accessible to current user) + return null; + } + private static boolean isStringType(PrimitiveType type) { if (type.getLogicalTypeAnnotation() == null) { @@ -373,4 +489,75 @@ private static void validateFileMetadata(ParquetDataSourceId dataSourceId, FileM Optional.ofNullable(fileMetaData.getKeyValueMetaData().get("writer.time.zone"))); writeValidation.validateColumns(dataSourceId, fileMetaData.getSchema()); } + + private static ColumnMetaData decryptMetadata(RowGroup rowGroup, ColumnCryptoMetaData cryptoMetaData, ColumnChunk columnChunk, InternalFileDecryptor fileDecryptor, int columnOrdinal) + { + EncryptionWithColumnKey columnKeyStruct = cryptoMetaData.getENCRYPTION_WITH_COLUMN_KEY(); + List pathList = columnKeyStruct.getPath_in_schema().stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()); + + byte[] columnKeyMetadata = columnKeyStruct.getKey_metadata(); + ColumnPath columnPath = ColumnPath.get(pathList.toArray(new String[pathList.size()])); + byte[] encryptedMetadataBuffer = columnChunk.getEncrypted_column_metadata(); + + // Decrypt the ColumnMetaData + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.setColumnCryptoMetadata(columnPath, true, false, columnKeyMetadata, columnOrdinal); + ByteArrayInputStream tempInputStream = new ByteArrayInputStream(encryptedMetadataBuffer); + byte[] columnMetaDataAAD = AesCipher.createModuleAAD(fileDecryptor.getFileAAD(), ModuleType.ColumnMetaData, rowGroup.ordinal, columnOrdinal, -1); + try { + return Util.readColumnMetaData(tempInputStream, columnDecryptionSetup.getMetaDataDecryptor(), columnMetaDataAAD); + } + catch (IOException e) { + throw new 
ParquetCryptoRuntimeException(columnPath + ". Failed to decrypt column metadata", e); + } + } + + /*public static ColumnChunkMetadata buildColumnChunkMetaData(Optional fileCreatedBy, ColumnMetaData metaData, ColumnPath columnPath, PrimitiveType type) + { + return ColumnChunkMetadata.get( + columnPath, + type, + CompressionCodecName.fromParquet(metaData.codec), + PARQUET_METADATA_CONVERTER.convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(fileCreatedBy, Optional.ofNullable(metaData.statistics), type), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + }*/ + + private static ColumnPath getPath(ColumnMetaData metaData) + { + String[] path = metaData.path_in_schema.stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)) + .toArray(String[]::new); + return ColumnPath.get(path); + } + + private static void verifyFooterIntegrity(InputStream metadataStream, InternalFileDecryptor fileDecryptor, int combinedFooterLength) + throws IOException + { + byte[] nonce = new byte[AesCipher.NONCE_LENGTH]; + metadataStream.read(nonce); + byte[] gcmTag = new byte[AesCipher.GCM_TAG_LENGTH]; + metadataStream.read(gcmTag); + + AesGcmEncryptor footerSigner = fileDecryptor.createSignedFooterEncryptor(); + int footerSignatureLength = AesCipher.NONCE_LENGTH + AesCipher.GCM_TAG_LENGTH; + byte[] serializedFooter = new byte[combinedFooterLength - footerSignatureLength]; + + //InputStream doesn't implement reset(). 
Here is to workaround + ((BasicSliceInput) metadataStream).setPosition(0); + metadataStream.read(serializedFooter, 0, serializedFooter.length); + + byte[] signedFooterAAD = AesCipher.createFooterAAD(fileDecryptor.getFileAAD()); + byte[] encryptedFooterBytes = footerSigner.encrypt(false, serializedFooter, nonce, signedFooterAAD); + byte[] calculatedTag = new byte[AesCipher.GCM_TAG_LENGTH]; + System.arraycopy(encryptedFooterBytes, encryptedFooterBytes.length - AesCipher.GCM_TAG_LENGTH, calculatedTag, 0, AesCipher.GCM_TAG_LENGTH); + if (!Arrays.equals(gcmTag, calculatedTag)) { + throw new TagVerificationException("Signature mismatch in plaintext footer"); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java index d8ec35c52fbe..799d4b111654 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java @@ -16,17 +16,24 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; +import io.airlift.slice.Slice; import io.trino.parquet.DataPage; import io.trino.parquet.DataPageV1; import io.trino.parquet.DataPageV2; import io.trino.parquet.DictionaryPage; import io.trino.parquet.Page; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.ModuleCipherFactory; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.CompressionCodec; +import 
org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -35,6 +42,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.parquet.ParquetCompressionUtils.decompress; import static io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages; import static java.util.Objects.requireNonNull; @@ -49,6 +57,10 @@ public final class PageReader private boolean dictionaryAlreadyRead; private int dataPageReadCount; + private int pageIndex; + private final BlockCipher.Decryptor blockDecryptor; + private byte[] dataPageAAD; + private byte[] dictionaryPageAAD; public static PageReader createPageReader( ParquetDataSourceId dataSourceId, @@ -56,7 +68,8 @@ public static PageReader createPageReader( ColumnChunkMetadata metadata, ColumnDescriptor columnDescriptor, @Nullable OffsetIndex offsetIndex, - Optional fileCreatedBy) + Optional fileCreatedBy, + Optional fileDecryptor) { // Parquet schema may specify a column definition as OPTIONAL even though there are no nulls in the actual data. 
// Row-group column statistics can be used to identify such cases and switch to faster non-nullable read @@ -64,20 +77,36 @@ public static PageReader createPageReader( Statistics> columnStatistics = metadata.getStatistics(); boolean hasNoNulls = columnStatistics != null && columnStatistics.getNumNulls() == 0; boolean hasOnlyDictionaryEncodedPages = isOnlyDictionaryEncodingPages(metadata); + byte[] fileAad = null; + BlockCipher.Decryptor dataDecryptor = null; + int columnOrdinal = -1; + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + fileAad = fileDecryptor.get().getFileAAD(); + dataDecryptor = columnDecryptionSetup.getDataDecryptor(); + columnOrdinal = columnDecryptionSetup.getOrdinal(); + } ParquetColumnChunkIterator compressedPages = new ParquetColumnChunkIterator( dataSourceId, fileCreatedBy, columnDescriptor, metadata, columnChunk, - offsetIndex); + offsetIndex, + fileDecryptor, + columnOrdinal); return new PageReader( dataSourceId, metadata.getCodec().getParquetCompressionCodec(), compressedPages, hasOnlyDictionaryEncodedPages, - hasNoNulls); + hasNoNulls, + dataDecryptor, + fileAad, + metadata.getRowGroupOrdinal(), + columnOrdinal); } @VisibleForTesting @@ -86,13 +115,22 @@ public PageReader( CompressionCodec codec, Iterator extends Page> compressedPages, boolean hasOnlyDictionaryEncodedPages, - boolean hasNoNulls) + boolean hasNoNulls, + BlockCipher.Decryptor blockDecryptor, + byte[] fileAAD, + int rowGroupOrdinal, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.codec = codec; this.compressedPages = Iterators.peekingIterator(compressedPages); this.hasOnlyDictionaryEncodedPages = hasOnlyDictionaryEncodedPages; this.hasNoNulls = hasNoNulls; + this.blockDecryptor = blockDecryptor; + if (null != blockDecryptor) { + dataPageAAD = 
AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DataPage, rowGroupOrdinal, columnOrdinal, 0); + dictionaryPageAAD = AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DictionaryPage, rowGroupOrdinal, columnOrdinal, -1); + } } public boolean hasNoNulls() @@ -114,18 +152,23 @@ public DataPage readPage() checkState(compressedPage instanceof DataPage, "Found page %s instead of a DataPage", compressedPage); dataPageReadCount++; try { + if (null != blockDecryptor) { + AesCipher.quickUpdatePageAAD(dataPageAAD, ((DataPage) compressedPage).getPageIndex()); + } + Slice slice = decryptSliceIfNeeded(compressedPage.getSlice(), dataPageAAD); if (compressedPage instanceof DataPageV1 dataPageV1) { if (!arePagesCompressed()) { return dataPageV1; } return new DataPageV1( - decompress(dataSourceId, codec, dataPageV1.getSlice(), dataPageV1.getUncompressedSize()), + decompress(dataSourceId, codec, slice, dataPageV1.getUncompressedSize()), dataPageV1.getValueCount(), dataPageV1.getUncompressedSize(), dataPageV1.getFirstRowIndex(), dataPageV1.getRepetitionLevelEncoding(), dataPageV1.getDefinitionLevelEncoding(), - dataPageV1.getValueEncoding()); + dataPageV1.getValueEncoding(), + dataPageV1.getPageIndex()); } DataPageV2 dataPageV2 = (DataPageV2) compressedPage; if (!dataPageV2.isCompressed()) { @@ -141,11 +184,12 @@ public DataPage readPage() dataPageV2.getRepetitionLevels(), dataPageV2.getDefinitionLevels(), dataPageV2.getDataEncoding(), - decompress(dataSourceId, codec, dataPageV2.getSlice(), uncompressedSize), + decompress(dataSourceId, codec, slice, uncompressedSize), dataPageV2.getUncompressedSize(), dataPageV2.getFirstRowIndex(), dataPageV2.getStatistics(), - false); + false, + dataPageV2.getPageIndex()); } catch (IOException e) { throw new RuntimeException("Could not decompress page", e); @@ -162,8 +206,9 @@ public DictionaryPage readDictionaryPage() } try { DictionaryPage compressedDictionaryPage = (DictionaryPage) compressedPages.next(); + Slice 
slice = decryptSliceIfNeeded(compressedDictionaryPage.getSlice(), dictionaryPageAAD); return new DictionaryPage( - decompress(dataSourceId, codec, compressedDictionaryPage.getSlice(), compressedDictionaryPage.getUncompressedSize()), + decompress(dataSourceId, codec, slice, compressedDictionaryPage.getUncompressedSize()), compressedDictionaryPage.getDictionarySize(), compressedDictionaryPage.getEncoding()); } @@ -199,4 +244,14 @@ private void verifyDictionaryPageRead() { checkArgument(dictionaryAlreadyRead, "Dictionary has to be read first"); } + + private Slice decryptSliceIfNeeded(Slice slice, byte[] aad) + throws IOException + { + if (blockDecryptor == null) { + return slice; + } + byte[] plainText = blockDecryptor.decrypt(slice.getBytes(), aad); + return wrappedBuffer(plainText); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java index 235c1b2d3d76..720d5f16151f 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java @@ -19,15 +19,21 @@ import io.trino.parquet.Page; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipher; +import io.trino.parquet.crypto.InternalColumnDecryptionSetup; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.crypto.ModuleCipherFactory; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.DataPageHeader; import org.apache.parquet.format.DataPageHeaderV2; import org.apache.parquet.format.DictionaryPageHeader; import 
org.apache.parquet.format.PageHeader; import org.apache.parquet.format.Util; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -51,6 +57,9 @@ public final class ParquetColumnChunkIterator private long valueCount; private int dataPageCount; + private Optional fileDecryptor; + private int columnOrdinal; + private boolean dictionaryWasRead; public ParquetColumnChunkIterator( ParquetDataSourceId dataSourceId, @@ -58,7 +67,9 @@ public ParquetColumnChunkIterator( ColumnDescriptor descriptor, ColumnChunkMetadata metadata, ChunkedInputStream input, - @Nullable OffsetIndex offsetIndex) + @Nullable OffsetIndex offsetIndex, + Optional fileDecryptor, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -66,6 +77,8 @@ public ParquetColumnChunkIterator( this.metadata = requireNonNull(metadata, "metadata is null"); this.input = requireNonNull(input, "input is null"); this.offsetIndex = offsetIndex; + this.fileDecryptor = fileDecryptor; + this.columnOrdinal = columnOrdinal; } @Override @@ -79,8 +92,32 @@ public Page next() { checkState(hasNext(), "No more data left to read in column (%s), metadata (%s), valueCount %s, dataPageCount %s", descriptor, metadata, valueCount, dataPageCount); + byte[] dataPageHeaderAAD = null; + BlockCipher.Decryptor headerBlockDecryptor = null; + InternalColumnDecryptionSetup columnDecryptionSetup = null; + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(descriptor.getPath()); + columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + headerBlockDecryptor = columnDecryptionSetup.getMetaDataDecryptor(); + if (null != headerBlockDecryptor) { + dataPageHeaderAAD = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), + ModuleCipherFactory.ModuleType.DataPageHeader, 
metadata.getRowGroupOrdinal(), columnOrdinal, dataPageCount); + } + } try { - PageHeader pageHeader = readPageHeader(); + byte[] pageHeaderAAD = dataPageHeaderAAD; + if (null != headerBlockDecryptor) { + // Important: this verifies file integrity (makes sure dictionary page had not been removed) + if (!(dictionaryWasRead || !metadata.hasDictionaryPage())) { + pageHeaderAAD = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), + ModuleCipherFactory.ModuleType.DictionaryPageHeader, metadata.getRowGroupOrdinal(), + columnOrdinal, -1); + } + else { + AesCipher.quickUpdatePageAAD(dataPageHeaderAAD, dataPageCount); + } + } + PageHeader pageHeader = readPageHeader(headerBlockDecryptor, pageHeaderAAD); int uncompressedPageSize = pageHeader.getUncompressed_page_size(); int compressedPageSize = pageHeader.getCompressed_page_size(); Page result = null; @@ -90,13 +127,14 @@ public Page next() throw new ParquetCorruptionException(dataSourceId, "Column (%s) has a dictionary page after the first position in column chunk", descriptor); } result = readDictionaryPage(pageHeader, pageHeader.getUncompressed_page_size(), pageHeader.getCompressed_page_size()); + dictionaryWasRead = true; break; case DATA_PAGE: - result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; case DATA_PAGE_V2: - result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; default: @@ -110,10 +148,10 @@ public Page next() } } - private PageHeader readPageHeader() + private PageHeader readPageHeader(BlockCipher.Decryptor headerBlockDecryptor, byte[] 
pageHeaderAAD) throws IOException { - return Util.readPageHeader(input); + return Util.readPageHeader(input, headerBlockDecryptor, pageHeaderAAD); } private boolean hasMorePages(long valuesCountReadSoFar, int dataPageCountReadSoFar) @@ -139,7 +177,8 @@ private DataPageV1 readDataPageV1( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeader dataHeaderV1 = pageHeader.getData_page_header(); @@ -151,14 +190,16 @@ private DataPageV1 readDataPageV1( firstRowIndex, getParquetEncoding(Encoding.valueOf(dataHeaderV1.getRepetition_level_encoding().name())), getParquetEncoding(Encoding.valueOf(dataHeaderV1.getDefinition_level_encoding().name())), - getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name()))); + getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name())), + pageIndex); } private DataPageV2 readDataPageV2( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeaderV2 dataHeaderV2 = pageHeader.getData_page_header_v2(); @@ -178,7 +219,8 @@ private DataPageV2 readDataPageV2( fileCreatedBy, Optional.ofNullable(dataHeaderV2.getStatistics()), descriptor.getPrimitiveType()), - dataHeaderV2.isIs_compressed()); + dataHeaderV2.isIs_compressed(), + pageIndex); } private static OptionalLong getFirstRowIndex(int pageIndex, OffsetIndex offsetIndex) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 0ad000ccd420..128375e5a32e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -30,6 +30,9 @@ import io.trino.parquet.ParquetReaderOptions; import 
io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.PrimitiveField; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.InternalFileDecryptor; +import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; import io.trino.parquet.predicate.TupleDomainParquetPredicate; @@ -129,6 +132,7 @@ public class ParquetReader private final Map> codecMetrics; private long columnIndexRowsFiltered = -1; + private final Optional fileDecryptor; public ParquetReader( Optional fileCreatedBy, @@ -140,7 +144,8 @@ public ParquetReader( ParquetReaderOptions options, Function exceptionTransform, Optional parquetPredicate, - Optional writeValidation) + Optional writeValidation, + Optional fileDecryptor) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -156,6 +161,7 @@ public ParquetReader( this.maxBatchSize = options.getMaxReadBlockRowCount(); this.columnReaders = new HashMap<>(); this.maxBytesPerCell = new HashMap<>(); + this.fileDecryptor = fileDecryptor; this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); validateWrite( @@ -264,7 +270,7 @@ public long lastBatchStartRow() return firstRowIndexInGroup + nextRowInGroup - batchSize; } - private int nextBatch() + public int nextBatch() throws IOException { if (nextRowInGroup >= currentGroupRowCount && !advanceToNextRowGroup()) { @@ -457,9 +463,16 @@ private ColumnChunk readPrimitive(PrimitiveField field) offsetIndex = getFilteredOffsetIndex(rowRanges, currentRowGroup, currentBlockMetadata.getRowCount(), metadata.getPath()); } ChunkedInputStream columnChunkInputStream = chunkReaders.get(new ChunkKey(fieldId, currentRowGroup)); - columnReader.setPageReader( - createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy), - Optional.ofNullable(rowRanges)); + if 
(isEncryptedColumn(fileDecryptor, columnDescriptor)) { + columnReader.setPageReader( + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptor), + Optional.ofNullable(rowRanges)); + } + else { + columnReader.setPageReader( + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptor), + Optional.ofNullable(rowRanges)); + } } ColumnChunk columnChunk = columnReader.readPrimitive(); @@ -491,6 +504,19 @@ public Metrics getMetrics() return new Metrics(metrics.buildOrThrow()); } + private ColumnChunkMetadata getColumnChunkMetaData(BlockMetadata blockMetaData, ColumnDescriptor columnDescriptor) + throws IOException + { + for (ColumnChunkMetadata metadata : blockMetaData.columns()) { + if (!HiddenColumnChunkMetaData.isHiddenColumn(metadata)) { + if (metadata.getPath().equals(ColumnPath.get(columnDescriptor.getPath()))) { + return metadata; + } + } + } + throw new ParquetCorruptionException(dataSource.getId(), "Metadata is missing for column: %s", columnDescriptor); + } + private void initializeColumnReaders() { for (PrimitiveField field : primitiveFields) { @@ -612,4 +638,10 @@ private void validateWrite(java.util.function.Predicate throw new ParquetCorruptionException(dataSource.getId(), "Write validation failed: " + messageFormat, args); } } + + private boolean isEncryptedColumn(Optional fileDecryptor, ColumnDescriptor columnDescriptor) + { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + return fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile() && fileDecryptor.get().getColumnSetup(columnPath).isEncrypted(); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 651d86040ef5..9eb40a5665e4 100644 --- 
a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -237,7 +237,7 @@ public void validate(ParquetDataSource input) checkState(validationBuilder.isPresent(), "validation is not enabled"); ParquetWriteValidation writeValidation = validationBuilder.get().build(); try { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation)); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation), Optional.empty()); try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { // fully load the page @@ -293,7 +293,8 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada return new RuntimeException(exception); }, Optional.empty(), - Optional.of(writeValidation)); + Optional.of(writeValidation), + Optional.empty()); } private void recordValidation(Consumer task) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java index 9f7918115838..e6cdd9825e77 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java @@ -225,7 +225,7 @@ public void setup() testData.getColumnNames(), testData.getPages()), new ParquetReaderOptions()); - parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); columnNames = columns.stream() .map(TpchColumn::getColumnName) .collect(toImmutableList()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java 
b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index febdaccf617b..59280c6de102 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -164,6 +164,7 @@ public static ParquetReader createParquetReader( return new RuntimeException(exception); }, Optional.of(parquetPredicate), + Optional.empty(), Optional.empty()); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java index fc47c42d8d82..448ef7dc26a8 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java @@ -105,7 +105,7 @@ public int read() throws IOException { ColumnReader columnReader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false, null, null, -1, -1); columnReader.setPageReader(pageReader, Optional.empty()); int rowsRead = 0; while (rowsRead < dataPositions) { @@ -133,7 +133,8 @@ private DataPage createDataPage(ValuesWriter writer, int valuesCount) OptionalLong.empty(), RLE, RLE, - getParquetEncoding(writer.getEncoding())); + getParquetEncoding(writer.getEncoding()), + 0); } protected static void run(Class> clazz) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java index 6a3fccb1e281..37dde42f5e57 100644 --- 
a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java @@ -564,7 +564,11 @@ else if (dictionaryEncoding == DictionaryEncoding.MIXED) { UNCOMPRESSED, inputPages.iterator(), dictionaryEncoding == DictionaryEncoding.ALL || (dictionaryEncoding == DictionaryEncoding.MIXED && testingPages.size() == 1), - false); + false, + null, + null, + -1, + -1); } private static List createDataPages(List testingPages, ValuesWriter encoder, int maxDef, boolean required) @@ -599,7 +603,8 @@ private static DataPage createDataPage(TestingPage testingPage, ValuesWriter enc valueCount * 4, OptionalLong.of(testingPage.pageRowRange().start()), null, - false); + false, + 0); encoder.reset(); return dataPage; } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java index 445b61268c33..8b8fe067c88a 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java @@ -660,7 +660,8 @@ protected static DataPage createDataPage( OptionalLong.empty(), getParquetEncoding(repetitionWriter.getEncoding()), getParquetEncoding(definitionWriter.getEncoding()), - encoding); + encoding, + 0); } return new DataPageV2( valueCount, @@ -673,7 +674,8 @@ protected static DataPage createDataPage( definitionBytes.length + repetitionBytes.length + valueBytes.length, OptionalLong.empty(), null, - false); + false, + 0); } protected static PageReader getPageReaderMock(List dataPages, @Nullable DictionaryPage dictionaryPage) @@ -699,7 +701,7 @@ protected static PageReader getPageReaderMock(List dataPages, @Nullabl return ((DataPageV2) page).getDataEncoding(); }) .allMatch(encoding -> encoding == PLAIN_DICTIONARY || encoding == 
RLE_DICTIONARY), - hasNoNulls); + hasNoNulls, null, null, -1, -1); } private DataPage createDataPage(DataPageVersion version, ParquetEncoding encoding, ValuesWriter writer, int valueCount) @@ -713,7 +715,7 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin { Slice slice = Slices.wrappedBuffer(writer.getBytes().toByteArray()); if (version == V1) { - return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding); + return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding, 0); } return new DataPageV2( valueCount, @@ -726,7 +728,8 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin slice.length(), firstRowIndex, null, - false); + false, + 0); } private static ValuesWriter getLevelsWriter(int maxLevel, int valueCount) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java new file mode 100644 index 000000000000..ac6981666b57 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncDecPropertiesHelper.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import io.trino.parquet.crypto.ColumnEncryptionProperties; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.FileDecryptionProperties; +import io.trino.parquet.crypto.FileEncryptionProperties; +import io.trino.parquet.crypto.ParquetCipher; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class EncDecPropertiesHelper +{ + private EncDecPropertiesHelper() + { + } + + private static class DecryptionKeyRetrieverMock + implements DecryptionKeyRetriever + { + private final Map keyMap = new HashMap<>(); + + public DecryptionKeyRetrieverMock putKey(String keyId, byte[] keyBytes) + { + keyMap.put(keyId, keyBytes); + return this; + } + + @Override + public byte[] getKey(byte[] keyMetaData) + { + String keyId = new String(keyMetaData, StandardCharsets.UTF_8); + return keyMap.get(keyId); + } + } + + private static final byte[] FOOTER_KEY = {0x01, 0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}; + private static final byte[] FOOTER_KEY_METADATA = "footkey".getBytes(StandardCharsets.UTF_8); + private static final byte[] COL_KEY = {0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11}; + private static final byte[] COL_KEY_METADATA = "col".getBytes(StandardCharsets.UTF_8); + + public static FileDecryptionProperties getFileDecryptionProperties() + throws IOException + { + DecryptionKeyRetrieverMock keyRetriever = new DecryptionKeyRetrieverMock(); + keyRetriever.putKey("footkey", FOOTER_KEY); + keyRetriever.putKey("col", COL_KEY); + return FileDecryptionProperties.builder().withPlaintextFilesAllowed().withKeyRetriever(keyRetriever).build(); + } + + public static FileEncryptionProperties getFileEncryptionProperties(List encryptColumns, ParquetCipher cipher, 
Boolean encryptFooter) + { + if (encryptColumns.size() == 0) { + return null; + } + + Map columnPropertyMap = new HashMap<>(); + for (String encryptColumn : encryptColumns) { + ColumnPath columnPath = ColumnPath.fromDotString(encryptColumn); + ColumnEncryptionProperties columnEncryptionProperties = ColumnEncryptionProperties.builder(columnPath) + .withKey(COL_KEY) + .withKeyMetaData(COL_KEY_METADATA) + .build(); + columnPropertyMap.put(columnPath, columnEncryptionProperties); + } + + FileEncryptionProperties.Builder encryptionPropertiesBuilder = + FileEncryptionProperties.builder(FOOTER_KEY) + .withFooterKeyMetadata(FOOTER_KEY_METADATA) + .withAlgorithm(cipher) + .withEncryptedColumns(columnPropertyMap); + + if (!encryptFooter) { + encryptionPropertiesBuilder.withPlaintextFooter(); + } + + return encryptionPropertiesBuilder.build(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java new file mode 100644 index 000000000000..d7677525ef13 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import org.apache.parquet.example.data.simple.SimpleGroup; + +public class EncryptionTestFile +{ + private final String fileName; + private final SimpleGroup[] fileContent; + + public EncryptionTestFile(String fileName, SimpleGroup[] fileContent) + { + this.fileName = fileName; + this.fileContent = fileContent; + } + + public String getFileName() + { + return this.fileName; + } + + public SimpleGroup[] getFileContent() + { + return this.fileContent; + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java new file mode 100644 index 000000000000..dd46ccb689b0 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.hadoop.fs.FSDataInputStream; + +import java.io.IOException; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class MockInputStreamTail +{ + public static final int MAX_SUPPORTED_PADDING_BYTES = 64; + private static final int MAXIMUM_READ_LENGTH = Integer.MAX_VALUE - (MAX_SUPPORTED_PADDING_BYTES + 1); + + private final Slice tailSlice; + private final long fileSize; + + private MockInputStreamTail(long fileSize, Slice tailSlice) + { + this.tailSlice = requireNonNull(tailSlice, "tailSlice is null"); + this.fileSize = fileSize; + checkArgument(fileSize >= 0, "fileSize is negative: %s", fileSize); + checkArgument(tailSlice.length() <= fileSize, "length (%s) is greater than fileSize (%s)", tailSlice.length(), fileSize); + } + + public static MockInputStreamTail readTail(String path, long paddedFileSize, FSDataInputStream inputStream, int length) + throws IOException + { + checkArgument(length >= 0, "length is negative: %s", length); + checkArgument(length <= MAXIMUM_READ_LENGTH, "length (%s) exceeds maximum (%s)", length, MAXIMUM_READ_LENGTH); + long readSize = min(paddedFileSize, (length + MAX_SUPPORTED_PADDING_BYTES)); + long position = paddedFileSize - readSize; + // Actual read will be 1 byte larger to ensure we encounter an EOF where expected + byte[] buffer = new byte[toIntExact(readSize + 1)]; + int bytesRead = 0; + long startPos = inputStream.getPos(); + try { + inputStream.seek(position); + while (bytesRead < buffer.length) { + int n = inputStream.read(buffer, bytesRead, buffer.length - bytesRead); + if (n < 0) { + break; + } + bytesRead += n; + } + } + finally { + inputStream.seek(startPos); + } + if 
(bytesRead > readSize) { + throw rejectInvalidFileSize(path, paddedFileSize); + } + return new MockInputStreamTail(position + bytesRead, Slices.wrappedBuffer(buffer, max(0, bytesRead - length), min(bytesRead, length))); + } + + public static long readTailForFileSize(String path, long paddedFileSize, FSDataInputStream inputStream) + throws IOException + { + long position = max(paddedFileSize - MAX_SUPPORTED_PADDING_BYTES, 0); + long maxEOFAt = paddedFileSize + 1; + long startPos = inputStream.getPos(); + try { + inputStream.seek(position); + int c; + while (position < maxEOFAt) { + c = inputStream.read(); + if (c < 0) { + return position; + } + position++; + } + throw rejectInvalidFileSize(path, paddedFileSize); + } + finally { + inputStream.seek(startPos); + } + } + + private static IOException rejectInvalidFileSize(String path, long reportedSize) + throws IOException + { + throw new IOException(format("Incorrect file size (%s) for file (end of stream not reached): %s", reportedSize, path)); + } + + public long getFileSize() + { + return fileSize; + } + + public Slice getTailSlice() + { + return tailSlice; + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java new file mode 100644 index 000000000000..2652e2da3301 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java @@ -0,0 +1,335 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.airlift.units.DataSize; +import io.trino.memory.context.AggregatedMemoryContext; +import io.trino.parquet.ChunkReader; +import io.trino.parquet.DiskRange; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; +import org.apache.hadoop.fs.FSDataInputStream; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.lang.Math.toIntExact; +import static java.util.Comparator.comparingLong; +import static java.util.Objects.requireNonNull; + +public class MockParquetDataSource + implements ParquetDataSource +{ + private final ParquetDataSourceId id; + private final long estimatedSize; + private final FSDataInputStream inputStream; + private long readTimeNanos; + private long readBytes; + private final ParquetReaderOptions options; + + public MockParquetDataSource( + ParquetDataSourceId id, + long estimatedSize, + FSDataInputStream inputStream, + ParquetReaderOptions options) + { + this.id = requireNonNull(id, "id is null"); + this.estimatedSize = estimatedSize; + this.inputStream = inputStream; + this.options = requireNonNull(options, "options is null"); + } + + @Override + public ParquetDataSourceId getId() + { + return id; + } + + @Override + public final long getReadBytes() + { + return 
readBytes; + } + + @Override + public long getReadTimeNanos() + { + return readTimeNanos; + } + + @Override + public final long getEstimatedSize() + { + return estimatedSize; + } + + @Override + public void close() + throws IOException + { + inputStream.close(); + } + + @Override + public Slice readTail(int length) + { + long start = System.nanoTime(); + Slice tailSlice; + try { + // Handle potentially imprecise file lengths by reading the footer + MockInputStreamTail fileTail = MockInputStreamTail.readTail(getId().toString(), getEstimatedSize(), inputStream, length); + tailSlice = fileTail.getTailSlice(); + } + catch (IOException e) { + throw new RuntimeException("Error reading tail from %s with length"); + } + long currentReadTimeNanos = System.nanoTime() - start; + + readTimeNanos += currentReadTimeNanos; + readBytes += tailSlice.length(); + return tailSlice; + } + + @Override + public final Slice readFully(long position, int length) + { + byte[] buffer = new byte[length]; + readFully(position, buffer, 0, length); + return Slices.wrappedBuffer(buffer); + } + + @Override + public final Map planRead(ListMultimap diskRanges, AggregatedMemoryContext memoryContext) + { + requireNonNull(diskRanges, "diskRanges is null"); + + if (diskRanges.isEmpty()) { + return ImmutableMap.of(); + } + + return planChunksRead(diskRanges, memoryContext).asMap() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> new ChunkedInputStream(entry.getValue()))); + } + + public ListMultimap planChunksRead(ListMultimap diskRanges, AggregatedMemoryContext memoryContext) + { + requireNonNull(diskRanges, "diskRanges is null"); + + if (diskRanges.isEmpty()) { + return ImmutableListMultimap.of(); + } + + // + // Note: this code does not use the stream APIs to avoid any extra object allocation + // + + // split disk ranges into "big" and "small" + ImmutableListMultimap.Builder smallRangesBuilder = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder 
largeRangesBuilder = ImmutableListMultimap.builder(); + for (Map.Entry entry : diskRanges.entries()) { + if (entry.getValue().getLength() <= options.getMaxBufferSize().toBytes()) { + smallRangesBuilder.put(entry); + } + else { + largeRangesBuilder.put(entry); + } + } + ListMultimap smallRanges = smallRangesBuilder.build(); + ListMultimap largeRanges = largeRangesBuilder.build(); + + // read ranges + ImmutableListMultimap.Builder slices = ImmutableListMultimap.builder(); + slices.putAll(readSmallDiskRanges(smallRanges)); + slices.putAll(readLargeDiskRanges(largeRanges)); + + return slices.build(); + } + + private void readFully(long position, byte[] buffer, int bufferOffset, int bufferLength) + { + readBytes += bufferLength; + + long start = System.nanoTime(); + try { + inputStream.readFully(position, buffer, bufferOffset, bufferLength); + } + catch (Exception e) { + throw new RuntimeException("Error reading from %s " + id + " at position " + position); + } + long currentReadTimeNanos = System.nanoTime() - start; + + readTimeNanos += currentReadTimeNanos; + } + + private ListMultimap readSmallDiskRanges(ListMultimap diskRanges) + { + if (diskRanges.isEmpty()) { + return ImmutableListMultimap.of(); + } + + Iterable mergedRanges = mergeAdjacentDiskRanges(diskRanges.values(), options.getMaxMergeDistance(), options.getMaxBufferSize()); + + ImmutableListMultimap.Builder slices = ImmutableListMultimap.builder(); + for (DiskRange mergedRange : mergedRanges) { + ReferenceCountedReader mergedRangeLoader = new ReferenceCountedReader(mergedRange); + + for (Map.Entry diskRangeEntry : diskRanges.entries()) { + DiskRange diskRange = diskRangeEntry.getValue(); + if (mergedRange.contains(diskRange)) { + mergedRangeLoader.addReference(); + + slices.put(diskRangeEntry.getKey(), new ChunkReader() + { + @Override + public Slice read() + { + int offset = toIntExact(diskRange.getOffset() - mergedRange.getOffset()); + return mergedRangeLoader.read().slice(offset, 
Long.valueOf(diskRange.getLength()).intValue()); + } + + @Override + public void free() + { + mergedRangeLoader.free(); + } + + @Override + public long getDiskOffset() + { + return diskRange.getOffset(); + } + }); + } + } + + mergedRangeLoader.free(); + } + + ListMultimap sliceStreams = slices.build(); + verify(sliceStreams.keySet().equals(diskRanges.keySet())); + return sliceStreams; + } + + private ListMultimap readLargeDiskRanges(ListMultimap diskRanges) + { + if (diskRanges.isEmpty()) { + return ImmutableListMultimap.of(); + } + + ImmutableListMultimap.Builder slices = ImmutableListMultimap.builder(); + for (Map.Entry entry : diskRanges.entries()) { + slices.put(entry.getKey(), new ReferenceCountedReader(entry.getValue())); + } + return slices.build(); + } + + private static List mergeAdjacentDiskRanges(Collection diskRanges, DataSize maxMergeDistance, DataSize maxReadSize) + { + // sort ranges by start offset + List ranges = new ArrayList<>(diskRanges); + ranges.sort(comparingLong(DiskRange::getOffset)); + + long maxReadSizeBytes = maxReadSize.toBytes(); + long maxMergeDistanceBytes = maxMergeDistance.toBytes(); + + // merge overlapping ranges + ImmutableList.Builder result = ImmutableList.builder(); + DiskRange last = ranges.get(0); + for (int i = 1; i < ranges.size(); i++) { + DiskRange current = ranges.get(i); + DiskRange merged = null; + boolean blockTooLong = false; + try { + merged = last.span(current); + } + catch (ArithmeticException e) { + blockTooLong = true; + } + if (!blockTooLong && merged.getLength() <= maxReadSizeBytes && last.getEnd() + maxMergeDistanceBytes >= current.getOffset()) { + last = merged; + } + else { + result.add(last); + last = current; + } + } + result.add(last); + + return result.build(); + } + + private class ReferenceCountedReader + implements ChunkReader + { + private final DiskRange range; + private Slice data; + private int referenceCount = 1; + + public ReferenceCountedReader(DiskRange range) + { + this.range = range; + } 
+ + public void addReference() + { + checkState(referenceCount > 0, "Chunk reader is already closed"); + referenceCount++; + } + + @Override + public Slice read() + { + checkState(referenceCount > 0, "Chunk reader is already closed"); + + if (data == null) { + byte[] buffer = new byte[Long.valueOf(range.getLength()).intValue()]; + readFully(range.getOffset(), buffer, 0, buffer.length); + data = Slices.wrappedBuffer(buffer); + } + + return data; + } + + @Override + public void free() + { + checkState(referenceCount > 0, "Reference count is already 0"); + + referenceCount--; + if (referenceCount == 0) { + data = null; + } + } + + @Override + public long getDiskOffset() + { + return range.getOffset(); + } + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java index d42725e5acb2..7f448bdbed2d 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java @@ -50,7 +50,7 @@ public void testReadFloatDouble() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("byte_stream_split_float_and_double.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); readAndCompare(reader, getExpectedValues()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java new file mode 100644 index 000000000000..c178d5be0261 
--- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import com.google.common.collect.ImmutableSet; +import io.trino.parquet.crypto.HiddenColumnChunkMetaData; +import io.trino.parquet.crypto.HiddenColumnException; +import io.trino.parquet.metadata.ColumnChunkMetadata; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.EncodingStats; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Types; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.Set; + +import static org.apache.parquet.column.Encoding.PLAIN; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestHiddenColumnChunkMetaData +{ + @Test + public void testIsHiddenColumn() + { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(ColumnPath.fromDotString("a.b.c"), + "hdfs:/foo/bar/a.parquet"); + assertThat(HiddenColumnChunkMetaData.isHiddenColumn(column)).isTrue(); + } + + @Test + public void testIsNotHiddenColumn() + { + Set encodingSet = 
Collections.singleton(Encoding.RLE); + EncodingStats encodingStats = new EncodingStats.Builder() + .withV2Pages() + .addDictEncoding(PLAIN) + .addDataEncodings(ImmutableSet.copyOf(encodingSet)).build(); + PrimitiveType type = Types.optional(BINARY).named(""); + Statistics> stats = Statistics.createStats(type); + ColumnChunkMetadata column = ColumnChunkMetadata.get(ColumnPath.fromDotString("a.b.c"), type, + CompressionCodecName.GZIP, encodingStats, encodingSet, stats, -1, -1, -1, -1, -1); + assertThat(HiddenColumnChunkMetaData.isHiddenColumn(column)).isFalse(); + } + + @Test(expectedExceptions = HiddenColumnException.class) + public void testHiddenColumnException() + { + ColumnChunkMetadata column = new HiddenColumnChunkMetaData(ColumnPath.fromDotString("a.b.c"), + "hdfs:/foo/bar/a.parquet"); + column.getStatistics(); + } + + @Test + public void testNoHiddenColumnException() + { + Set encodingSet = Collections.singleton(Encoding.RLE); + EncodingStats encodingStats = new EncodingStats.Builder() + .withV2Pages() + .addDictEncoding(PLAIN) + .addDataEncodings(ImmutableSet.copyOf(encodingSet)).build(); + PrimitiveType type = Types.optional(BINARY).named(""); + Statistics> stats = Statistics.createStats(type); + ColumnChunkMetadata column = ColumnChunkMetadata.get(ColumnPath.fromDotString("a.b.c"), type, + CompressionCodecName.GZIP, encodingStats, encodingSet, stats, -1, -1, -1, -1, -1); + column.getStatistics(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java index aabb734e5b0c..49e4fc2f9d80 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java @@ -112,7 +112,7 @@ public void testNanosOutsideDayRange() ParquetDataSource dataSource = new FileParquetDataSource( new 
File(Resources.getResource("int96_timestamps_nanos_outside_day_range.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); @@ -166,11 +166,12 @@ private void testVariousTimestamps(TimestampType type) slice.length(), OptionalLong.empty(), null, - false); + false, + 0); // Read and assert ColumnReaderFactory columnReaderFactory = new ColumnReaderFactory(DateTimeZone.UTC, new ParquetReaderOptions()); ColumnReader reader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false, null, null, -1, -1); reader.setPageReader(pageReader, Optional.empty()); reader.prepareNextRead(valueCount); Block block = reader.readPrimitive().getBlock(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java index 102e2b4fc01b..a94ff78cf8f2 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java @@ -25,6 +25,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetTypeUtils; +import io.trino.parquet.crypto.InternalFileDecryptor; import io.trino.parquet.metadata.ColumnChunkMetadata; import org.apache.parquet.column.ColumnDescriptor; import 
org.apache.parquet.column.EncodingStats; @@ -183,7 +184,7 @@ public void dictionaryPage(CompressionCodec compressionCodec, DataPageType dataP out.write(compressedDataPage); byte[] bytes = out.toByteArray(); - PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); DictionaryPage uncompressedDictionaryPage = pageReader.readDictionaryPage(); assertThat(uncompressedDictionaryPage.getDictionarySize()).isEqualTo(dictionaryPageHeader.getDictionary_page_header().getNum_values()); assertEncodingEquals(uncompressedDictionaryPage.getEncoding(), dictionaryPageHeader.getDictionary_page_header().getEncoding()); @@ -193,7 +194,7 @@ public void dictionaryPage(CompressionCodec compressionCodec, DataPageType dataP assertPages(compressionCodec, totalValueCount, 3, pageHeader, compressedDataPage, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); // only dictionary - pageReader = createPageReader(0, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(Arrays.copyOf(bytes, dictionaryPageSize)))); + pageReader = createPageReader(0, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(Arrays.copyOf(bytes, dictionaryPageSize))), null, -1); assertThatThrownBy(pageReader::readDictionaryPage) .isInstanceOf(IllegalStateException.class) .hasMessageStartingWith("No more data left to read"); @@ -236,7 +237,7 @@ public void dictionaryPageNotFirst() int totalValueCount = valueCount * 2; // There is a dictionary, but it's there as the second page - PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(totalValueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); 
assertThat(pageReader.readDictionaryPage()).isNull(); assertThat(pageReader.readPage()).isNotNull(); assertThatThrownBy(pageReader::readPage) @@ -270,7 +271,7 @@ public void unusedDictionaryPage() byte[] bytes = out.toByteArray(); // There is a dictionary, but it's there as the second page - PageReader pageReader = createPageReader(valueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes))); + PageReader pageReader = createPageReader(valueCount, compressionCodec, true, ImmutableList.of(Slices.wrappedBuffer(bytes)), null, -1); assertThat(pageReader.readDictionaryPage()).isNotNull(); assertThat(pageReader.readPage()).isNotNull(); assertThat(pageReader.readPage()).isNull(); @@ -298,7 +299,7 @@ private static void assertPages( List slices) throws IOException { - PageReader pageReader = createPageReader(valueCount, compressionCodec, hasDictionary, slices); + PageReader pageReader = createPageReader(valueCount, compressionCodec, hasDictionary, slices, null, -1); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); assertThat(dictionaryPage != null).isEqualTo(hasDictionary); @@ -383,7 +384,7 @@ private static byte[] compress(CompressionCodec compressionCodec, byte[] bytes, throw new IllegalArgumentException("unsupported compression code " + compressionCodec); } - private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices) + private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices, InternalFileDecryptor fileDecryptor, int rowGroupOrdinal) throws IOException { EncodingStats.Builder encodingStats = new EncodingStats.Builder(); @@ -409,7 +410,8 @@ private static PageReader createPageReader(int valueCount, CompressionCodec comp columnChunkMetaData, new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, INT32, ""), 0, 0), null, - Optional.empty()); + Optional.empty(), + 
Optional.ofNullable(fileDecryptor)); } private static void assertDataPageEquals(PageHeader pageHeader, byte[] dataPage, byte[] compressedDataPage, DataPage decompressedPage) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java index 2ef475a7644f..0c4f3011dbb1 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java @@ -79,7 +79,7 @@ public void testColumnReaderMemoryUsage() columnNames, generateInputPages(types, 100, 5)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThan(1); // Verify file has only non-dictionary encodings as dictionary memory usage is already tested in TestFlatColumnReader#testMemoryUsage parquetMetadata.getBlocks().forEach(block -> { @@ -132,7 +132,7 @@ public void testEmptyRowRangesWithColumnIndex() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("lineitem_sorted_by_shipdate/data.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(2); // The predicate and the file are prepared so that page indexes will result in non-overlapping row ranges and eliminate the entire first row group // while the second row group still has to be read @@ -193,7 +193,7 @@ private void testReadingOldParquetFiles(File file, List columnNames, Typ file, new ParquetReaderOptions()); ConnectorSession 
session = TestingConnectorSession.builder().build(); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), ImmutableList.of(columnType), columnNames)) { Page page = reader.nextPage(); Iterator> expected = expectedValues.iterator(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java index 390608f445a9..99ae226bca08 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java @@ -60,7 +60,7 @@ private void testTimeMillsInt32(TimeType timeType) ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("time_millis_int32.snappy.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java index a3efb46b6d71..8222899ab90b 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java @@ -137,8 +137,9 @@ private static PageReader getSimplePageReaderMock(ParquetEncoding encoding) OptionalLong.empty(), 
encoding, encoding, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, null, null, -1, -1); } private static PageReader getNullOnlyPageReaderMock() @@ -154,7 +155,8 @@ private static PageReader getNullOnlyPageReaderMock() OptionalLong.empty(), RLE, RLE, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, null, null, -1, -1); } } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java index 846080c3297a..717474419d11 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java @@ -127,7 +127,7 @@ public void testWrittenPageSize() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -141,6 +141,7 @@ public void testWrittenPageSize() chunkMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -176,7 +177,7 @@ public void testWrittenPageValueCount() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = 
MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -194,6 +195,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -213,6 +215,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columnb"}, new PrimitiveType(REQUIRED, INT64, "columnb"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -256,8 +259,7 @@ public void testLargeStringTruncation() columnNames, ImmutableList.of(new Page(2, blockA, blockB))), new ParquetReaderOptions()); - - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); BlockMetadata blockMetaData = getOnlyElement(parquetMetadata.getBlocks()); ColumnChunkMetadata chunkMetaData = blockMetaData.columns().get(0); @@ -290,7 +292,7 @@ public void testColumnReordering() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(10); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { // Verify that the columns are stored in the same order as the metadata @@ -347,7 +349,7 @@ public void testDictionaryPageOffset() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - 
ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(1); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { ColumnChunkMetadata chunkMetaData = getOnlyElement(blockMetaData.columns()); @@ -393,7 +395,7 @@ public void testWriteBloomFilters(Type type, List> data) generateInputPages(types, 100, data)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); // Check that bloom filters are right after each other int bloomFilterSize = Integer.highestOneBit(BlockSplitBloomFilter.optimalNumOfBits(BLOOM_FILTER_EXPECTED_ENTRIES, DEFAULT_BLOOM_FILTER_FPP) / 8) << 1; for (BlockMetadata block : parquetMetadata.getBlocks()) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index 5fe764a72756..eb0a41cd8108 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -362,7 +362,7 @@ private Slice writeMergeResult(Slice path, FileDeletion deletion) TrinoInputFile inputFile = fileSystem.newInputFile(Location.of(path.toStringUtf8())); try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, parquetReaderOptions, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); long rowCount = 
parquetMetadata.getBlocks().stream().map(BlockMetadata::rowCount).mapToLong(Long::longValue).sum(); RoaringBitmapArray rowsRetained = new RoaringBitmapArray(); rowsRetained.addRange(0, rowCount - 1); @@ -637,7 +637,8 @@ private ReaderPageSource createParquetPageSource(Location path) new ParquetReaderOptions().withBloomFilter(false), Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + null); } @Override diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index f08ecc84f839..c552b1944e2c 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -254,7 +254,8 @@ public ConnectorPageSource createPageSource( options, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(split.getFileSize())); + OptionalLong.of(split.getFileSize()), + null); Optional projectionsAdapter = pageSource.getReaderColumns().map(readerColumns -> new ReaderProjectionsAdapter( @@ -306,7 +307,7 @@ private PositionDeleteFilter readDeletes( public Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java index 8f686205e239..5330c6edd100 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java @@ -184,7 +184,7 @@ public DataFileInfo getDataFileInfo() { Location path = rootTableLocation.appendPath(relativeFilePath); FileMetaData fileMetaData = fileWriter.getFileMetadata(); - ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString())); + ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString()), Optional.empty(), false); return new DataFileInfo( relativeFilePath, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java index 7f5d4b8a88c6..3c57de2ef2f3 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java @@ -205,7 +205,8 @@ private static DeltaLakePageSource createDeltaLakePageSource( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.empty()); + OptionalLong.of(split.fileSize()), + null); verify(pageSource.getReaderColumns().isEmpty(), "Unexpected reader columns: %s", pageSource.getReaderColumns().orElse(null)); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 
04673aeab8ea..985cc433aaea 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -231,7 +231,8 @@ public CheckpointEntryIterator( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + Optional.empty()); this.pageSource = (ParquetPageSource) pageSource.get(); try { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java index 70cdce9c5e4f..7f9050feb512 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java @@ -329,7 +329,7 @@ private void testOptimizeWithColumnMappingMode(String columnMappingMode) TrinoInputFile inputFile = new LocalInputFile(tableLocation.resolve(addFileEntry.getPath()).toFile()); ParquetMetadata parquetMetadata = MetadataReader.readFooter( new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()), - Optional.empty()); + Optional.empty(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); PrimitiveType physicalType = getOnlyElement(fileMetaData.getSchema().getColumns().iterator()).getPrimitiveType(); assertThat(physicalType.getName()).isEqualTo(physicalName); diff --git a/plugin/trino-geospatial/pom.xml b/plugin/trino-geospatial/pom.xml index 6d975cb4232a..1636b6f25102 100644 --- a/plugin/trino-geospatial/pom.xml +++ b/plugin/trino-geospatial/pom.xml @@ -230,4 +230,20 @@ test