diff --git a/lib/trino-parquet/pom.xml b/lib/trino-parquet/pom.xml index 4bf90122f230..b0a4108c7e8d 100644 --- a/lib/trino-parquet/pom.xml +++ b/lib/trino-parquet/pom.xml @@ -47,6 +47,11 @@ units + + io.trino + trino-filesystem + + io.trino trino-memory-context diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java index bbece17c9b7e..123ddac61520 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java @@ -13,6 +13,8 @@ */ package io.trino.parquet; +import io.airlift.slice.Slice; + import java.util.OptionalLong; public abstract sealed class DataPage @@ -41,4 +43,6 @@ public int getValueCount() { return valueCount; } + + public abstract Slice getSlice(); } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/EncryptionUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/EncryptionUtils.java new file mode 100644 index 000000000000..c05584a624bf --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/EncryptionUtils.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet; + +import io.airlift.log.Logger; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import org.apache.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.crypto.InternalFileDecryptor; +import org.apache.parquet.crypto.TrinoCryptoConfigurationUtil; +import org.apache.parquet.crypto.TrinoDecryptionPropertiesFactory; + +import java.lang.reflect.InvocationTargetException; +import java.util.Optional; + +public class EncryptionUtils +{ + public static final Logger LOG = Logger.get(EncryptionUtils.class); + + private EncryptionUtils() {} + + public static Optional createDecryptor(ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + { + if (parquetReaderOptions == null || filePath == null || trinoFileSystem == null) { + return Optional.empty(); + } + + TrinoDecryptionPropertiesFactory cryptoFactory = loadDecryptionPropertiesFactory(parquetReaderOptions); + FileDecryptionProperties fileDecryptionProperties = (cryptoFactory == null) ? null : cryptoFactory.getFileDecryptionProperties(parquetReaderOptions, filePath, trinoFileSystem); + return (fileDecryptionProperties == null) ? 
Optional.empty() : Optional.of(new InternalFileDecryptor(fileDecryptionProperties)); + } + + private static TrinoDecryptionPropertiesFactory loadDecryptionPropertiesFactory(ParquetReaderOptions trinoParquetCryptoConfig) + { + final Class foundClass = TrinoCryptoConfigurationUtil.getClassFromConfig( + trinoParquetCryptoConfig.getCryptoFactoryClass(), TrinoDecryptionPropertiesFactory.class); + + if (foundClass == null) { + return null; + } + + try { + return (TrinoDecryptionPropertiesFactory) foundClass.getConstructor().newInstance(); + } + catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + LOG.warn("could not instantiate decryptionPropertiesFactoryClass class: " + foundClass, e); + return null; + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java index e0e7d4418fbb..237d58061b44 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetReaderOptions.java @@ -14,6 +14,8 @@ package io.trino.parquet; import io.airlift.units.DataSize; +import org.apache.parquet.crypto.keytools.KmsClient; +import org.apache.parquet.crypto.keytools.TrinoKeyToolkit; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.units.DataSize.Unit.MEGABYTE; @@ -35,6 +37,12 @@ public class ParquetReaderOptions private final boolean useColumnIndex; private final boolean useBloomFilter; private final DataSize smallFileThreshold; + private final String cryptoFactoryClass; + private final String encryptionKmsClientClass; + private final String encryptionKmsInstanceId; + private final String encryptionKmsInstanceUrl; + private final String encryptionKeyAccessToken; + private final long encryptionCacheLifetimeSeconds; public ParquetReaderOptions() { @@ -46,6 +54,12 @@ public ParquetReaderOptions() useColumnIndex = true; useBloomFilter = true; smallFileThreshold = DEFAULT_SMALL_FILE_THRESHOLD; + this.cryptoFactoryClass = null; + this.encryptionKmsClientClass = null; + this.encryptionKmsInstanceId = null; + this.encryptionKmsInstanceUrl = null; + this.encryptionKeyAccessToken = KmsClient.KEY_ACCESS_TOKEN_DEFAULT; + this.encryptionCacheLifetimeSeconds = TrinoKeyToolkit.CACHE_LIFETIME_DEFAULT_SECONDS; } private ParquetReaderOptions( @@ -56,7 +70,13 @@ private ParquetReaderOptions( DataSize maxBufferSize, boolean useColumnIndex, boolean useBloomFilter, - DataSize smallFileThreshold) + DataSize smallFileThreshold, + String cryptoFactoryClass, + String encryptionKmsClientClass, + String encryptionKmsInstanceId, + String encryptionKmsInstanceUrl, + String encryptionKeyAccessToken, + long encryptionCacheLifetimeSeconds) { this.ignoreStatistics = ignoreStatistics; this.maxReadBlockSize = requireNonNull(maxReadBlockSize, "maxReadBlockSize is null"); @@ -67,6 +87,12 @@ private ParquetReaderOptions( this.useColumnIndex = useColumnIndex; this.useBloomFilter = useBloomFilter; this.smallFileThreshold = requireNonNull(smallFileThreshold, "smallFileThreshold is null"); + this.cryptoFactoryClass = cryptoFactoryClass; + this.encryptionKmsClientClass = encryptionKmsClientClass; + this.encryptionKmsInstanceId = encryptionKmsInstanceId; + this.encryptionKmsInstanceUrl = encryptionKmsInstanceUrl; + this.encryptionKeyAccessToken = requireNonNull(encryptionKeyAccessToken, "encryptionKeyAccessToken is null"); + this.encryptionCacheLifetimeSeconds 
= encryptionCacheLifetimeSeconds; } public boolean isIgnoreStatistics() @@ -109,6 +135,36 @@ public DataSize getSmallFileThreshold() return smallFileThreshold; } + public String getCryptoFactoryClass() + { + return cryptoFactoryClass; + } + + public long getEncryptionCacheLifetimeSeconds() + { + return this.encryptionCacheLifetimeSeconds; + } + + public String getEncryptionKeyAccessToken() + { + return this.encryptionKeyAccessToken; + } + + public String getEncryptionKmsInstanceId() + { + return this.encryptionKmsInstanceId; + } + + public String getEncryptionKmsInstanceUrl() + { + return this.encryptionKmsInstanceUrl; + } + + public String getEncryptionKmsClientClass() + { + return this.encryptionKmsClientClass; + } + public ParquetReaderOptions withIgnoreStatistics(boolean ignoreStatistics) { return new ParquetReaderOptions( @@ -119,7 +175,13 @@ public ParquetReaderOptions withIgnoreStatistics(boolean ignoreStatistics) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withMaxReadBlockSize(DataSize maxReadBlockSize) @@ -132,7 +194,13 @@ public ParquetReaderOptions withMaxReadBlockSize(DataSize maxReadBlockSize) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withMaxReadBlockRowCount(int maxReadBlockRowCount) @@ -145,7 +213,13 @@ public ParquetReaderOptions withMaxReadBlockRowCount(int maxReadBlockRowCount) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withMaxMergeDistance(DataSize maxMergeDistance) @@ -158,7 +232,13 @@ public ParquetReaderOptions withMaxMergeDistance(DataSize maxMergeDistance) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withMaxBufferSize(DataSize maxBufferSize) @@ -171,7 +251,13 @@ public ParquetReaderOptions withMaxBufferSize(DataSize maxBufferSize) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withUseColumnIndex(boolean useColumnIndex) @@ -184,7 +270,13 @@ public ParquetReaderOptions withUseColumnIndex(boolean useColumnIndex) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withBloomFilter(boolean useBloomFilter) @@ -197,7 +289,13 @@ public ParquetReaderOptions withBloomFilter(boolean useBloomFilter) maxBufferSize, useColumnIndex, 
useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } public ParquetReaderOptions withSmallFileThreshold(DataSize smallFileThreshold) @@ -210,6 +308,126 @@ public ParquetReaderOptions withSmallFileThreshold(DataSize smallFileThreshold) maxBufferSize, useColumnIndex, useBloomFilter, - smallFileThreshold); + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); + } + + public ParquetReaderOptions withCryptoFactoryClass(String cryptoFactoryClass) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); + } + + public ParquetReaderOptions withEncryptionKmsClientClass(String encryptionKmsClientClass) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); + } + + public ParquetReaderOptions withEncryptionKmsInstanceId(String encryptionKmsInstanceId) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); + } + + public ParquetReaderOptions withEncryptionKmsInstanceUrl(String encryptionKmsInstanceUrl) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); + } + + public ParquetReaderOptions withEncryptionKeyAccessToken(String encryptionKeyAccessToken) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); + } + + public ParquetReaderOptions withEncryptionCacheLifetimeSeconds(long encryptionCacheLifetimeSeconds) + { + return new ParquetReaderOptions( + ignoreStatistics, + maxReadBlockSize, + maxReadBlockRowCount, + maxMergeDistance, + maxBufferSize, + useColumnIndex, + useBloomFilter, + smallFileThreshold, + cryptoFactoryClass, + encryptionKmsClientClass, + encryptionKmsInstanceId, + encryptionKmsInstanceUrl, + encryptionKeyAccessToken, + encryptionCacheLifetimeSeconds); } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java 
b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index 3768645bf199..d632ceac874d 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -14,23 +14,37 @@ package io.trino.parquet.reader; import io.airlift.log.Logger; +import io.airlift.slice.BasicSliceInput; import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetWriteValidation; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; +import org.apache.parquet.crypto.AesCipher; +import org.apache.parquet.crypto.AesGcmEncryptor; +import org.apache.parquet.crypto.HiddenColumnChunkMetaData; +import org.apache.parquet.crypto.InternalColumnDecryptionSetup; +import org.apache.parquet.crypto.InternalFileDecryptor; +import org.apache.parquet.crypto.KeyAccessDeniedException; +import org.apache.parquet.crypto.ModuleCipherFactory; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; +import org.apache.parquet.crypto.TagVerificationException; +import org.apache.parquet.format.BlockCipher.Decryptor; import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnCryptoMetaData; import org.apache.parquet.format.ColumnMetaData; import org.apache.parquet.format.Encoding; +import org.apache.parquet.format.EncryptionWithColumnKey; +import org.apache.parquet.format.FileCryptoMetaData; import org.apache.parquet.format.FileMetaData; import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.SchemaElement; import org.apache.parquet.format.Statistics; import org.apache.parquet.format.Type; +import org.apache.parquet.format.Util; import org.apache.parquet.format.converter.ParquetMetadataConverter; import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; @@ -45,8 +59,8 @@ import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -58,20 +72,31 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.parquet.ParquetValidationUtils.validateParquet; import static java.lang.Boolean.FALSE; import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static org.apache.parquet.crypto.AesCipher.GCM_TAG_LENGTH; +import static org.apache.parquet.crypto.AesCipher.NONCE_LENGTH; +import static org.apache.parquet.format.Util.readFileCryptoMetaData; import static org.apache.parquet.format.Util.readFileMetaData; import static org.apache.parquet.format.converter.ParquetMetadataConverterUtil.getLogicalTypeAnnotation; +import static org.apache.parquet.hadoop.ParquetFileWriter.EF_MAGIC_STR; +import static org.apache.parquet.hadoop.ParquetFileWriter.MAGIC_STR; public final class MetadataReader { private static final Logger log = 
Logger.get(MetadataReader.class); - private static final Slice MAGIC = Slices.utf8Slice("PAR1"); + private static final Slice MAGIC = wrappedBuffer(MAGIC_STR.getBytes(US_ASCII)); + private static final Slice EMAGIC = wrappedBuffer(EF_MAGIC_STR.getBytes(US_ASCII)); + private static final int POST_SCRIPT_SIZE = Integer.BYTES + MAGIC.length(); // Typical 1GB files produced by Trino were found to have footer size between 30-40KB private static final int EXPECTED_FOOTER_SIZE = 48 * 1024; @@ -79,7 +104,66 @@ public final class MetadataReader private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation) + private static void verifyFooterIntegrity(BasicSliceInput from, InternalFileDecryptor fileDecryptor, int combinedFooterLength) + { + byte[] nonce = new byte[NONCE_LENGTH]; + from.read(nonce); + byte[] gcmTag = new byte[GCM_TAG_LENGTH]; + from.read(gcmTag); + + AesGcmEncryptor footerSigner = fileDecryptor.createSignedFooterEncryptor(); + int footerSignatureLength = NONCE_LENGTH + GCM_TAG_LENGTH; + byte[] serializedFooter = new byte[combinedFooterLength - footerSignatureLength]; + from.setPosition(0); + from.read(serializedFooter, 0, serializedFooter.length); + + byte[] signedFooterAuthenticationData = AesCipher.createFooterAAD(fileDecryptor.getFileAAD()); + byte[] encryptedFooterBytes = footerSigner.encrypt(false, serializedFooter, nonce, signedFooterAuthenticationData); + byte[] calculatedTag = new byte[GCM_TAG_LENGTH]; + System.arraycopy(encryptedFooterBytes, encryptedFooterBytes.length - GCM_TAG_LENGTH, calculatedTag, 0, GCM_TAG_LENGTH); + if (!Arrays.equals(gcmTag, calculatedTag)) { + throw new TagVerificationException("Signature mismatch in plaintext footer"); + } + } + + private static ColumnMetaData decryptMetadata(RowGroup rowGroup, ColumnCryptoMetaData cryptoMetaData, ColumnChunk columnChunk, InternalFileDecryptor fileDecryptor, int columnOrdinal) + { + EncryptionWithColumnKey columnKeyStruct = cryptoMetaData.getENCRYPTION_WITH_COLUMN_KEY(); + List pathList = columnKeyStruct.getPath_in_schema(); + pathList = pathList.stream().map(String::toLowerCase).collect(Collectors.toList()); + byte[] columnKeyMetadata = columnKeyStruct.getKey_metadata(); + ColumnPath columnPath = ColumnPath.get(pathList.toArray(new String[pathList.size()])); + byte[] encryptedMetadataBuffer = columnChunk.getEncrypted_column_metadata(); + + // Decrypt the ColumnMetaData + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.setColumnCryptoMetadata(columnPath, true, false, columnKeyMetadata, columnOrdinal); + ByteArrayInputStream tempInputStream = new ByteArrayInputStream(encryptedMetadataBuffer); + byte[] columnMetaDataAAD = AesCipher.createModuleAAD(fileDecryptor.getFileAAD(), ModuleCipherFactory.ModuleType.ColumnMetaData, rowGroup.ordinal, columnOrdinal, -1); + try { + return Util.readColumnMetaData(tempInputStream, columnDecryptionSetup.getMetaDataDecryptor(), columnMetaDataAAD); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException(columnPath + ". 
Failed to decrypt column metadata", e); + } + } + + public static ColumnChunkMetaData buildColumnChunkMetaData(Optional fileCreatedBy, ColumnMetaData metaData, ColumnPath columnPath, PrimitiveType type) + { + return ColumnChunkMetaData.get( + columnPath, + type, + CompressionCodecName.fromParquet(metaData.codec), + PARQUET_METADATA_CONVERTER.convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + readStats(fileCreatedBy, Optional.ofNullable(metaData.statistics), type), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + } + + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation, Optional fileDecryptor) throws IOException { // Parquet File Layout: @@ -98,7 +182,9 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< Slice buffer = dataSource.readTail(toIntExact(expectedReadSize)); Slice magic = buffer.slice(buffer.length() - MAGIC.length(), MAGIC.length()); - validateParquet(MAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s got: %s", MAGIC.toStringUtf8(), magic.toStringUtf8()); + validateParquet(MAGIC.equals(magic) || EMAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s got: %s", MAGIC.toStringUtf8(), magic.toStringUtf8()); + boolean encryptedFooterMode = EMAGIC.equals(magic); + checkArgument(!encryptedFooterMode || fileDecryptor.isPresent(), "fileDecryptionProperties cannot be null when encryptedFooterMode is true"); int metadataLength = buffer.getInt(buffer.length() - POST_SCRIPT_SIZE); long metadataIndex = estimatedFileSize - POST_SCRIPT_SIZE - metadataLength; @@ -113,15 +199,45 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< // initial read was not large enough, so just read again with the correct size buffer = dataSource.readTail(completeFooterSize); } - InputStream metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); + BasicSliceInput metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); + + Decryptor footerDecryptor = null; + // additional authenticated data for AES cipher + byte[] additionalAuthenticationData = null; + + if (encryptedFooterMode) { + FileCryptoMetaData fileCryptoMetaData = readFileCryptoMetaData(metadataStream); + fileDecryptor.get().setFileCryptoMetaData(fileCryptoMetaData.getEncryption_algorithm(), true, fileCryptoMetaData.getKey_metadata()); + footerDecryptor = fileDecryptor.get().fetchFooterDecryptor(); + additionalAuthenticationData = AesCipher.createFooterAAD(fileDecryptor.get().getFileAAD()); + } + + FileMetaData fileMetaData = readFileMetaData(metadataStream, footerDecryptor, additionalAuthenticationData); + // Reader attached fileDecryptor. The file could be encrypted with plaintext footer or the whole file is plaintext. 
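+ // The two cases are distinguished below by whether the footer metadata carries an encryption_algorithm; for an encrypted file with a plaintext footer, the footer signature is verified when footer integrity checking is requested.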
+ if (!encryptedFooterMode && fileDecryptor.isPresent()) { + if (!fileMetaData.isSetEncryption_algorithm()) { // Plaintext file + fileDecryptor.get().setPlaintextFile(); + // Detect that the file is not encrypted by mistake + if (!fileDecryptor.get().plaintextFilesAllowed()) { + throw new ParquetCryptoRuntimeException("Applying decryptor on plaintext file"); + } + } + else { // Encrypted file with plaintext footer + // if no fileDecryptor, can still read plaintext columns + fileDecryptor.get().setFileCryptoMetaData(fileMetaData.getEncryption_algorithm(), false, + fileMetaData.getFooter_signing_key_metadata()); + if (fileDecryptor.get().checkFooterIntegrity()) { + verifyFooterIntegrity(metadataStream, fileDecryptor.get(), metadataLength); + } + } + } - FileMetaData fileMetaData = readFileMetaData(metadataStream); - ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, dataSource.getId()); + ParquetMetadata parquetMetadata = createParquetMetadata(fileMetaData, dataSource.getId(), fileDecryptor, encryptedFooterMode); validateFileMetadata(dataSource.getId(), parquetMetadata.getFileMetaData(), parquetWriteValidation); return parquetMetadata; } - public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, ParquetDataSourceId dataSourceId) + public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, ParquetDataSourceId dataSourceId, Optional fileDecryptor, boolean encryptedFooterMode) throws ParquetCorruptionException { List schema = fileMetaData.getSchema(); @@ -138,30 +254,57 @@ public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, P List columns = rowGroup.getColumns(); validateParquet(!columns.isEmpty(), dataSourceId, "No columns in row group: %s", rowGroup); String filePath = columns.get(0).getFile_path(); + int columnOrdinal = -1; + for (ColumnChunk columnChunk : columns) { validateParquet( (filePath == null && columnChunk.getFile_path() == null) || (filePath != null && filePath.equals(columnChunk.getFile_path())), dataSourceId, "all column chunks of the same row group must be in the same file"); + + columnOrdinal++; ColumnMetaData metaData = columnChunk.meta_data; - String[] path = metaData.path_in_schema.stream() - .map(value -> value.toLowerCase(Locale.ENGLISH)) - .toArray(String[]::new); - ColumnPath columnPath = ColumnPath.get(path); - PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); - ColumnChunkMetaData column = ColumnChunkMetaData.get( - columnPath, - primitiveType, - CompressionCodecName.fromParquet(metaData.codec), - PARQUET_METADATA_CONVERTER.convertEncodingStats(metaData.encoding_stats), - readEncodings(metaData.encodings), - readStats(Optional.ofNullable(fileMetaData.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), - metaData.data_page_offset, - metaData.dictionary_page_offset, - metaData.num_values, - metaData.total_compressed_size, - metaData.total_uncompressed_size); + ColumnCryptoMetaData cryptoMetaData = columnChunk.getCrypto_metadata(); + ColumnPath columnPath = null; + + if (null == cryptoMetaData) { // Plaintext column + columnPath = getColumnPath(metaData); + if (fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile()) { + // mark this column as plaintext in encrypted file decryptor + fileDecryptor.get().setColumnCryptoMetadata(columnPath, false, false, (byte[]) null, columnOrdinal); + } + } + else { // Encrypted column + if (cryptoMetaData.isSetENCRYPTION_WITH_FOOTER_KEY()) { // Column encrypted with footer key 
+ if (!encryptedFooterMode) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key in file with plaintext footer"); + } + if (null == metaData) { + throw new ParquetCryptoRuntimeException("ColumnMetaData not set in Encryption with Footer key"); + } + if (!fileDecryptor.isPresent()) { + throw new ParquetCryptoRuntimeException("Column encrypted with footer key: No keys available"); + } + columnPath = getColumnPath(metaData); + fileDecryptor.get().setColumnCryptoMetadata(columnPath, true, true, (byte[]) null, columnOrdinal); + } + else { // Column encrypted with column key + try { + // TODO: We decrypt data before filter projection. This could send unnecessary traffic to the KMS, but so far it has not been a problem in production. + // parquet-mr uses lazy decryption instead, but that would require changing ColumnChunkMetadata. We will improve this later. + metaData = decryptMetadata(rowGroup, cryptoMetaData, columnChunk, fileDecryptor.get(), columnOrdinal); + columnPath = getColumnPath(metaData); + } + catch (KeyAccessDeniedException e) { + ColumnChunkMetaData column = new HiddenColumnChunkMetaData(columnPath, filePath); + blockMetaData.addColumn(column); + continue; + } + } + } + + ColumnChunkMetaData column = buildColumnChunkMetaData(Optional.ofNullable(fileMetaData.getCreated_by()), metaData, columnPath, messageType.getType(columnPath.toArray()).asPrimitiveType()); column.setColumnIndexReference(toColumnIndexReference(columnChunk)); column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); column.setBloomFilterOffset(metaData.bloom_filter_offset); @@ -186,6 +329,16 @@ public static ParquetMetadata createParquetMetadata(FileMetaData fileMetaData, P return new ParquetMetadata(parquetFileMetadata, blocks); } + private static ColumnPath getColumnPath(ColumnMetaData metaData) + { + ColumnPath columnPath; + String[] path = metaData.path_in_schema.stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)) + .toArray(String[]::new); + columnPath = ColumnPath.get(path); + return columnPath; + } + private static MessageType readParquetSchema(List schema) { Iterator schemaIterator = schema.iterator(); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java index 7dddae62af4d..0ac4525b3d5d 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java @@ -16,6 +16,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; +import io.airlift.slice.Slice; import io.trino.parquet.DataPage; import io.trino.parquet.DataPageV1; import io.trino.parquet.DataPageV2; @@ -25,8 +26,14 @@ import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.crypto.AesCipher; +import org.apache.parquet.crypto.InternalColumnDecryptionSetup; +import org.apache.parquet.crypto.InternalFileDecryptor; +import org.apache.parquet.crypto.ModuleCipherFactory; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.CompressionCodec; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -35,6 +42,7 @@ import static
com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.parquet.ParquetCompressionUtils.decompress; import static io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages; import static java.util.Objects.requireNonNull; @@ -50,13 +58,20 @@ public final class PageReader private boolean dictionaryAlreadyRead; private int dataPageReadCount; + // For parquet modular encryption + private final Optional blockDecryptor; + private byte[] dataPageAdditionalAuthenticationData; + private byte[] dictionaryPageAdditionalAuthenticationData; + public static PageReader createPageReader( ParquetDataSourceId dataSourceId, ChunkedInputStream columnChunk, ColumnChunkMetaData metadata, ColumnDescriptor columnDescriptor, @Nullable OffsetIndex offsetIndex, - Optional fileCreatedBy) + Optional fileCreatedBy, + Optional fileDecryptor, + int rowGroupOrdinal) { // Parquet schema may specify a column definition as OPTIONAL even though there are no nulls in the actual data. // Row-group column statistics can be used to identify such cases and switch to faster non-nullable read @@ -64,20 +79,43 @@ public static PageReader createPageReader( Statistics columnStatistics = metadata.getStatistics(); boolean hasNoNulls = columnStatistics != null && columnStatistics.getNumNulls() == 0; boolean hasOnlyDictionaryEncodedPages = isOnlyDictionaryEncodingPages(metadata); + + Optional dataDecryptor = Optional.empty(); + byte[] dataPageAdditionalAuthenticationData = null; + byte[] dictionaryPageAdditionalAuthenticationData = null; + int columnOrdinal = -1; + + if (fileDecryptor.isPresent()) { + InternalFileDecryptor fileDecryptorValue = fileDecryptor.get(); + byte[] fileAdditionalAuthenticationData = fileDecryptorValue.getFileAAD(); + + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptorValue.getColumnSetup(columnPath); + dataDecryptor = getDataDecryptor(columnDecryptionSetup); + + columnOrdinal = columnDecryptionSetup.getOrdinal(); + dataPageAdditionalAuthenticationData = AesCipher.createModuleAAD(fileAdditionalAuthenticationData, ModuleCipherFactory.ModuleType.DataPage, rowGroupOrdinal, columnOrdinal, 0); + dictionaryPageAdditionalAuthenticationData = AesCipher.createModuleAAD(fileAdditionalAuthenticationData, ModuleCipherFactory.ModuleType.DictionaryPage, rowGroupOrdinal, columnOrdinal, -1); + } + ParquetColumnChunkIterator compressedPages = new ParquetColumnChunkIterator( dataSourceId, fileCreatedBy, columnDescriptor, metadata, columnChunk, - offsetIndex); + offsetIndex, + fileDecryptor, + rowGroupOrdinal, + columnOrdinal); return new PageReader( dataSourceId, metadata.getCodec().getParquetCompressionCodec(), compressedPages, hasOnlyDictionaryEncodedPages, - hasNoNulls); + hasNoNulls, + dataDecryptor, dataPageAdditionalAuthenticationData, dictionaryPageAdditionalAuthenticationData); } @VisibleForTesting @@ -86,13 +124,20 @@ public PageReader( CompressionCodec codec, Iterator compressedPages, boolean hasOnlyDictionaryEncodedPages, - boolean hasNoNulls) + boolean hasNoNulls, + Optional blockDecryptor, + byte[] dataPageAdditionalAuthenticationData, + byte[] dictionaryPageAdditionalAuthenticationData) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.codec = codec; this.compressedPages = Iterators.peekingIterator(compressedPages); this.hasOnlyDictionaryEncodedPages = 
hasOnlyDictionaryEncodedPages; this.hasNoNulls = hasNoNulls; + + this.blockDecryptor = blockDecryptor; + this.dataPageAdditionalAuthenticationData = dataPageAdditionalAuthenticationData; + this.dictionaryPageAdditionalAuthenticationData = dictionaryPageAdditionalAuthenticationData; } public boolean hasNoNulls() @@ -110,16 +155,23 @@ public DataPage readPage() if (!compressedPages.hasNext()) { return null; } + + if (blockDecryptor.isPresent()) { + AesCipher.quickUpdatePageAAD(dataPageAdditionalAuthenticationData, dataPageReadCount); + } + Page compressedPage = compressedPages.next(); checkState(compressedPage instanceof DataPage, "Found page %s instead of a DataPage", compressedPage); dataPageReadCount++; try { + Slice slice = decryptSliceIfNeeded(((DataPage) compressedPage).getSlice(), dataPageAdditionalAuthenticationData); + if (compressedPage instanceof DataPageV1 dataPageV1) { if (!arePagesCompressed()) { return dataPageV1; } return new DataPageV1( - decompress(dataSourceId, codec, dataPageV1.getSlice(), dataPageV1.getUncompressedSize()), + decompress(dataSourceId, codec, slice, dataPageV1.getUncompressedSize()), dataPageV1.getValueCount(), dataPageV1.getUncompressedSize(), dataPageV1.getFirstRowIndex(), @@ -141,7 +193,7 @@ public DataPage readPage() dataPageV2.getRepetitionLevels(), dataPageV2.getDefinitionLevels(), dataPageV2.getDataEncoding(), - decompress(dataSourceId, codec, dataPageV2.getSlice(), uncompressedSize), + decompress(dataSourceId, codec, slice, uncompressedSize), dataPageV2.getUncompressedSize(), dataPageV2.getFirstRowIndex(), dataPageV2.getStatistics(), @@ -162,8 +214,9 @@ public DictionaryPage readDictionaryPage() } try { DictionaryPage compressedDictionaryPage = (DictionaryPage) compressedPages.next(); + Slice slice = decryptSliceIfNeeded(compressedDictionaryPage.getSlice(), dictionaryPageAdditionalAuthenticationData); return new DictionaryPage( - decompress(dataSourceId, codec, compressedDictionaryPage.getSlice(), compressedDictionaryPage.getUncompressedSize()), + decompress(dataSourceId, codec, slice, compressedDictionaryPage.getUncompressedSize()), compressedDictionaryPage.getDictionarySize(), compressedDictionaryPage.getEncoding()); } @@ -199,4 +252,23 @@ private void verifyDictionaryPageRead() { checkArgument(dictionaryAlreadyRead, "Dictionary has to be read first"); } + + // additional authenticated data for AES cipher + private Slice decryptSliceIfNeeded(Slice slice, byte[] additionalAuthenticationData) + throws IOException + { + if (blockDecryptor.isEmpty()) { + return slice; + } + byte[] plainText = blockDecryptor.get().decrypt(slice.getBytes(), additionalAuthenticationData); + return wrappedBuffer(plainText); + } + + private static Optional getDataDecryptor(InternalColumnDecryptionSetup columnDecryptionSetup) + { + if (columnDecryptionSetup == null || columnDecryptionSetup.getDataDecryptor() == null) { + return Optional.empty(); + } + return Optional.of(columnDecryptionSetup.getDataDecryptor()); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java index 577e5cb602fa..2a3d57699710 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java @@ -22,18 +22,26 @@ import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import 
org.apache.parquet.column.Encoding; +import org.apache.parquet.column.EncodingStats; +import org.apache.parquet.crypto.AesCipher; +import org.apache.parquet.crypto.InternalColumnDecryptionSetup; +import org.apache.parquet.crypto.InternalFileDecryptor; +import org.apache.parquet.crypto.ModuleCipherFactory; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.DataPageHeader; import org.apache.parquet.format.DataPageHeaderV2; import org.apache.parquet.format.DictionaryPageHeader; import org.apache.parquet.format.PageHeader; import org.apache.parquet.format.Util; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; import java.util.Iterator; import java.util.Optional; import java.util.OptionalLong; +import java.util.Set; import static com.google.common.base.Preconditions.checkState; import static io.trino.parquet.ParquetTypeUtils.getParquetEncoding; @@ -51,6 +59,12 @@ public final class ParquetColumnChunkIterator private long valueCount; private int dataPageCount; + private final byte[] dataPageHeaderAdditionalAuthenticationData; + private final byte[] fileAAD; + private final BlockCipher.Decryptor headerBlockDecryptor; + private final int rowGroupOrdinal; + private final int columnOrdinal; + private Page dictionaryPage; public ParquetColumnChunkIterator( ParquetDataSourceId dataSourceId, @@ -58,7 +72,10 @@ public ParquetColumnChunkIterator( ColumnDescriptor descriptor, ColumnChunkMetaData metadata, ChunkedInputStream input, - @Nullable OffsetIndex offsetIndex) + @Nullable OffsetIndex offsetIndex, + Optional fileDecryptor, + int rowGroupOrdinal, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -66,6 +83,27 @@ public ParquetColumnChunkIterator( this.metadata = requireNonNull(metadata, "metadata is null"); this.input = requireNonNull(input, "input is null"); this.offsetIndex = offsetIndex; + + this.rowGroupOrdinal = rowGroupOrdinal; + this.columnOrdinal = columnOrdinal; + + if (fileDecryptor.isPresent()) { + ColumnPath columnPath = ColumnPath.get(descriptor.getPath()); + InternalColumnDecryptionSetup columnDecryptionSetup = fileDecryptor.get().getColumnSetup(columnPath); + headerBlockDecryptor = columnDecryptionSetup.getMetaDataDecryptor(); + if (headerBlockDecryptor != null) { + this.dataPageHeaderAdditionalAuthenticationData = AesCipher.createModuleAAD(fileDecryptor.get().getFileAAD(), ModuleCipherFactory.ModuleType.DataPageHeader, rowGroupOrdinal, columnOrdinal, dataPageCount); + } + else { + this.dataPageHeaderAdditionalAuthenticationData = null; + } + fileAAD = fileDecryptor.get().getFileAAD(); + } + else { + this.headerBlockDecryptor = null; + this.dataPageHeaderAdditionalAuthenticationData = null; + this.fileAAD = null; + } } @Override @@ -80,7 +118,18 @@ public Page next() checkState(hasNext(), "No more data left to read in column (%s), metadata (%s), valueCount %s, dataPageCount %s", descriptor, metadata, valueCount, dataPageCount); try { - PageHeader pageHeader = readPageHeader(); + byte[] pageHeaderAdditionalAuthenticationData = dataPageHeaderAdditionalAuthenticationData; + if (headerBlockDecryptor != null) { + // Important: this verifies file integrity (makes sure dictionary page had not been removed) + if (dictionaryPage == null && hasDictionaryPage(metadata)) { + 
pageHeaderAdditionalAuthenticationData = AesCipher.createModuleAAD(fileAAD, ModuleCipherFactory.ModuleType.DictionaryPageHeader, rowGroupOrdinal, columnOrdinal, -1); + } + else { + AesCipher.quickUpdatePageAAD(dataPageHeaderAdditionalAuthenticationData, dataPageCount); + } + } + PageHeader pageHeader = Util.readPageHeader(input, headerBlockDecryptor, pageHeaderAdditionalAuthenticationData); + int uncompressedPageSize = pageHeader.getUncompressed_page_size(); int compressedPageSize = pageHeader.getCompressed_page_size(); Page result = null; @@ -90,6 +139,7 @@ public Page next() throw new ParquetCorruptionException(dataSourceId, "Column (%s) has a dictionary page after the first position in column chunk", descriptor); } result = readDictionaryPage(pageHeader, pageHeader.getUncompressed_page_size(), pageHeader.getCompressed_page_size()); + this.dictionaryPage = result; break; case DATA_PAGE: result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); @@ -110,10 +160,15 @@ public Page next() } } - private PageHeader readPageHeader() - throws IOException + private boolean hasDictionaryPage(ColumnChunkMetaData columnChunkMetaData) { - return Util.readPageHeader(input); + EncodingStats stats = columnChunkMetaData.getEncodingStats(); + if (stats != null) { + return stats.hasDictionaryPages() && stats.hasDictionaryEncodedPages(); + } + + Set encodings = columnChunkMetaData.getEncodings(); + return encodings.contains(Encoding.PLAIN_DICTIONARY) || encodings.contains(Encoding.RLE_DICTIONARY); } private boolean hasMorePages(long valuesCountReadSoFar, int dataPageCountReadSoFar) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index aa9633f2763b..b11fdb3a3e99 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -47,6 +47,7 @@ import io.trino.spi.type.Type; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.crypto.InternalFileDecryptor; import org.apache.parquet.filter2.compat.FilterCompat; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.hadoop.metadata.BlockMetaData; @@ -127,6 +128,7 @@ public class ParquetReader private final FilteredRowRanges[] blockRowRanges; private final ParquetBlockFactory blockFactory; private final Map> codecMetrics; + private final Optional fileDecryptor; private long columnIndexRowsFiltered = -1; @@ -140,7 +142,8 @@ public ParquetReader( ParquetReaderOptions options, Function exceptionTransform, Optional parquetPredicate, - Optional writeValidation) + Optional writeValidation, + Optional fileDecryptor) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -156,6 +159,7 @@ public ParquetReader( this.maxBatchSize = options.getMaxReadBlockRowCount(); this.columnReaders = new HashMap<>(); this.maxBytesPerCell = new HashMap<>(); + this.fileDecryptor = requireNonNull(fileDecryptor, "fileDecryptor is null"); this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); validateWrite( @@ -454,8 +458,14 @@ private ColumnChunk readPrimitive(PrimitiveField field) offsetIndex = getFilteredOffsetIndex(rowRanges, currentRowGroup, currentBlockMetadata.getRowCount(), metadata.getPath()); } ChunkedInputStream columnChunkInputStream 
= chunkReaders.get(new ChunkKey(fieldId, currentRowGroup)); + + Optional fileDecryptorForPageReader = fileDecryptor; + if (!isColumnEncrypted(fileDecryptor, columnDescriptor)) { + // The file decryptor is not needed by the page reader for plaintext columns. + fileDecryptorForPageReader = Optional.empty(); + } columnReader.setPageReader( - createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy), + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, fileDecryptorForPageReader, currentRowGroup), Optional.ofNullable(rowRanges)); } ColumnChunk columnChunk = columnReader.readPrimitive(); @@ -477,6 +487,12 @@ public List getColumnFields() return columnFields; } + private static boolean isColumnEncrypted(Optional fileDecryptor, ColumnDescriptor columnDescriptor) + { + ColumnPath columnPath = ColumnPath.get(columnDescriptor.getPath()); + return fileDecryptor.isPresent() && !fileDecryptor.get().plaintextFile() && fileDecryptor.get().getColumnSetup(columnPath).isEncrypted(); + } + public Metrics getMetrics() { ImmutableMap.Builder> metrics = ImmutableMap.>builder() diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 5b813482625f..925dd1f104ad 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -215,7 +215,7 @@ public void validate(ParquetDataSource input) checkState(validationBuilder.isPresent(), "validation is not enabled"); ParquetWriteValidation writeValidation = validationBuilder.get().build(); try { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation)); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation), Optional.empty()); try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { // fully load the page @@ -270,7 +270,8 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada return new RuntimeException(exception); }, Optional.empty(), - Optional.of(writeValidation)); + Optional.of(writeValidation), + Optional.empty()); } private void recordValidation(Consumer task) diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/HiddenColumnChunkMetaData.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/HiddenColumnChunkMetaData.java new file mode 100644 index 000000000000..1a0a936ffe82 --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/HiddenColumnChunkMetaData.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.parquet.crypto; + +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import static java.util.Objects.requireNonNull; + +public class HiddenColumnChunkMetaData + extends ColumnChunkMetaData +{ + private final ColumnPath path; + private final String filePath; + + public HiddenColumnChunkMetaData(ColumnPath path, String filePath) + { + super(null, null); + this.path = requireNonNull(path, "path should not be null"); + this.filePath = requireNonNull(filePath, "filePath should not be null"); + } + + @Override + public long getFirstDataPageOffset() + { + throw new HiddenColumnException(path.toArray(), filePath); + } + + @Override + public long getDictionaryPageOffset() + { + throw new HiddenColumnException(path.toArray(), filePath); + } + + @Override + public long getValueCount() + { + throw new HiddenColumnException(path.toArray(), this.filePath); + } + + @Override + public long getTotalUncompressedSize() + { + throw new HiddenColumnException(path.toArray(), filePath); + } + + @Override + public long getTotalSize() + { + throw new HiddenColumnException(path.toArray(), filePath); + } + + @Override + public Statistics getStatistics() + { + throw new HiddenColumnException(path.toArray(), filePath); + } + + public static boolean isHiddenColumn(ColumnChunkMetaData column) + { + return column instanceof HiddenColumnChunkMetaData; + } +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/HiddenColumnException.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/HiddenColumnException.java new file mode 100644 index 000000000000..88de60d2e120 --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/HiddenColumnException.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.parquet.crypto; + +import org.apache.parquet.ParquetRuntimeException; + +import java.util.Arrays; + +public class HiddenColumnException + extends ParquetRuntimeException +{ + private static final long serialVersionUID = 1L; + private final String column; + + public HiddenColumnException(String[] columnPath, String filePath) + { + super(String.format("User does not have access to the encryption key for encrypted column = %s for file: %s", Arrays.toString(columnPath), filePath)); + // We have to duplicate the toString() call because super() won't allow anything else before it + column = Arrays.toString(columnPath); + } + + public String getColumn() + { + return column; + } +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/TrinoCryptoConfigurationUtil.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/TrinoCryptoConfigurationUtil.java new file mode 100644 index 000000000000..bf075bc00ff6 --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/TrinoCryptoConfigurationUtil.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.crypto; + +import io.airlift.log.Logger; + +public class TrinoCryptoConfigurationUtil +{ + public static final Logger LOG = Logger.get(TrinoCryptoConfigurationUtil.class); + + private TrinoCryptoConfigurationUtil() + { + } + + public static Class getClassFromConfig(String className, Class assignableFrom) + { + try { + final Class foundClass = Class.forName(className); + if (!assignableFrom.isAssignableFrom(foundClass)) { + LOG.warn("class " + className + " is not a subclass of " + assignableFrom.getCanonicalName()); + return null; + } + return foundClass; + } + catch (ClassNotFoundException e) { + LOG.warn("could not instantiate class " + className, e); + return null; + } + } +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/TrinoDecryptionPropertiesFactory.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/TrinoDecryptionPropertiesFactory.java new file mode 100644 index 000000000000..7303b1fbd9e1 --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/TrinoDecryptionPropertiesFactory.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.parquet.crypto; + +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; + +public interface TrinoDecryptionPropertiesFactory +{ + // TODO(wyu): maybe create a dedicated config class in org.apache.parquet and convert ParquetReaderOptions to this class? + FileDecryptionProperties getFileDecryptionProperties(io.trino.parquet.ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + throws ParquetCryptoRuntimeException; +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java new file mode 100644 index 000000000000..c25faec50ac2 --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoFileKeyUnwrapper.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.parquet.crypto.keytools; + +import com.google.common.base.Strings; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import org.apache.parquet.crypto.DecryptionKeyRetriever; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; +import org.apache.parquet.crypto.keytools.TrinoKeyToolkit.KeyWithMasterID; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +import static org.apache.parquet.crypto.keytools.TrinoKeyToolkit.KEK_READ_CACHE_PER_TOKEN; +import static org.apache.parquet.crypto.keytools.TrinoKeyToolkit.KMS_CLIENT_CACHE_PER_TOKEN; + +public class TrinoFileKeyUnwrapper + implements DecryptionKeyRetriever +{ + private static final Logger LOG = LoggerFactory.getLogger(TrinoFileKeyUnwrapper.class); + + // A map of KEK_ID -> KEK bytes, for the current token + private final ConcurrentMap kekPerKekID; + private final Location parquetFilePath; + // TODO(wyu): shall we get it from Location or File + private final TrinoFileSystem trinoFileSystem; + private final String accessToken; + private final long cacheEntryLifetime; + private final ParquetReaderOptions parquetReaderOptions; + private TrinoKeyToolkit.TrinoKmsClientAndDetails kmsClientAndDetails; + private TrinoHadoopFSKeyMaterialStore keyMaterialStore; + private boolean checkedKeyMaterialInternalStorage; + + TrinoFileKeyUnwrapper(ParquetReaderOptions conf, Location filePath, TrinoFileSystem trinoFileSystem) + { + this.trinoFileSystem = trinoFileSystem; + this.parquetReaderOptions = conf; + this.parquetFilePath = filePath; + this.cacheEntryLifetime = 1000L * conf.getEncryptionCacheLifetimeSeconds(); + this.accessToken = conf.getEncryptionKeyAccessToken(); + this.kmsClientAndDetails = null; + this.keyMaterialStore = null; + this.checkedKeyMaterialInternalStorage = false; + + // Check cache upon each file reading (clean once in cacheEntryLifetime) +
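// Both caches are keyed by the key access token; entries older than cacheEntryLifetime are evicted. +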
KMS_CLIENT_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + KEK_READ_CACHE_PER_TOKEN.checkCacheForExpiredTokens(cacheEntryLifetime); + kekPerKekID = KEK_READ_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + if (LOG.isDebugEnabled()) { + LOG.debug("Creating file key unwrapper. KeyMaterialStore: {}; token snippet: {}", + keyMaterialStore, TrinoKeyToolkit.formatTokenForLog(accessToken)); + } + } + + @Override + public byte[] getKey(byte[] keyMetadataBytes) + { + KeyMetadata keyMetadata = KeyMetadata.parse(keyMetadataBytes); + + if (!checkedKeyMaterialInternalStorage) { + if (!keyMetadata.keyMaterialStoredInternally()) { + keyMaterialStore = new TrinoHadoopFSKeyMaterialStore(trinoFileSystem, parquetFilePath, false); + } + checkedKeyMaterialInternalStorage = true; + } + + KeyMaterial keyMaterial; + if (keyMetadata.keyMaterialStoredInternally()) { + // Internal key material storage: key material is inside key metadata + keyMaterial = keyMetadata.getKeyMaterial(); + } + else { + // External key material storage: key metadata contains a reference to a key in the material store + String keyIDinFile = keyMetadata.getKeyReference(); + String keyMaterialString = keyMaterialStore.getKeyMaterial(keyIDinFile); + if (null == keyMaterialString) { + throw new ParquetCryptoRuntimeException("Null key material for keyIDinFile: " + keyIDinFile); + } + keyMaterial = KeyMaterial.parse(keyMaterialString); + } + + return getDEKandMasterID(keyMaterial).getDataKey(); + } + + KeyWithMasterID getDEKandMasterID(KeyMaterial keyMaterial) + { + if (null == kmsClientAndDetails) { + kmsClientAndDetails = getKmsClientFromConfigOrKeyMaterial(keyMaterial); + } + + boolean doubleWrapping = keyMaterial.isDoubleWrapped(); + String masterKeyID = keyMaterial.getMasterKeyID(); + String encodedWrappedDEK = keyMaterial.getWrappedDEK(); + + byte[] dataKey; + TrinoKmsClient kmsClient = kmsClientAndDetails.getKmsClient(); + if (!doubleWrapping) { + dataKey = kmsClient.unwrapKey(encodedWrappedDEK, masterKeyID); + } + else { + // Get KEK + String encodedKekID = keyMaterial.getKekID(); + String encodedWrappedKEK = keyMaterial.getWrappedKEK(); + + byte[] kekBytes = kekPerKekID.computeIfAbsent(encodedKekID, + (k) -> kmsClient.unwrapKey(encodedWrappedKEK, masterKeyID)); + + if (null == kekBytes) { + throw new ParquetCryptoRuntimeException("Null KEK, after unwrapping in KMS with master key " + masterKeyID); + } + + // Decrypt the data key + byte[] aad = Base64.getDecoder().decode(encodedKekID); + dataKey = TrinoKeyToolkit.decryptKeyLocally(encodedWrappedDEK, kekBytes, aad); + } + + return new KeyWithMasterID(dataKey, masterKeyID); + } + + TrinoKeyToolkit.TrinoKmsClientAndDetails getKmsClientFromConfigOrKeyMaterial(KeyMaterial keyMaterial) + { + String kmsInstanceID = this.parquetReaderOptions.getEncryptionKmsInstanceId(); + if (Strings.isNullOrEmpty(kmsInstanceID)) { + kmsInstanceID = keyMaterial.getKmsInstanceID(); + if (null == kmsInstanceID) { + throw new ParquetCryptoRuntimeException("KMS instance ID is missing both in properties and file key material"); + } + } + + String kmsInstanceURL = this.parquetReaderOptions.getEncryptionKmsInstanceUrl(); + if (Strings.isNullOrEmpty(kmsInstanceURL)) { + kmsInstanceURL = keyMaterial.getKmsInstanceURL(); + if (null == kmsInstanceURL) { + throw new ParquetCryptoRuntimeException("KMS instance URL is missing both in properties and file key material"); + } + } + + TrinoKmsClient kmsClient = TrinoKeyToolkit.getKmsClient(kmsInstanceID, kmsInstanceURL, 
this.parquetReaderOptions, accessToken, cacheEntryLifetime); + if (null == kmsClient) { + throw new ParquetCryptoRuntimeException("KMSClient was not successfully created for reading encrypted data."); + } + + if (LOG.isDebugEnabled()) { + LOG.debug("File unwrapper - KmsClient: {}; InstanceId: {}; InstanceURL: {}", kmsClient, kmsInstanceID, kmsInstanceURL); + } + return new TrinoKeyToolkit.TrinoKmsClientAndDetails(kmsClient, kmsInstanceID, kmsInstanceURL); + } +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java new file mode 100644 index 000000000000..45bba144893d --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoHadoopFSKeyMaterialStore.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.crypto.keytools; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoInputStream; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Map; + +public class TrinoHadoopFSKeyMaterialStore +{ + public static final String KEY_MATERIAL_FILE_PREFIX = "_KEY_MATERIAL_FOR_"; + public static final String TEMP_FILE_PREFIX = "_TMP"; + public static final String KEY_MATERIAL_FILE_SUFFFIX = ".json"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private TrinoFileSystem trinoFileSystem; + private Map keyMaterialMap; + private Location keyMaterialFile; + + TrinoHadoopFSKeyMaterialStore(TrinoFileSystem trinoFileSystem, Location parquetFilePath, boolean tempStore) + { + this.trinoFileSystem = trinoFileSystem; + String fullPrefix = (tempStore ? 
TEMP_FILE_PREFIX : ""); + fullPrefix += KEY_MATERIAL_FILE_PREFIX; + keyMaterialFile = parquetFilePath.parentDirectory().appendSuffix( + fullPrefix + parquetFilePath.path() + KEY_MATERIAL_FILE_SUFFFIX); + } + + public String getKeyMaterial(String keyIDInFile) + throws ParquetCryptoRuntimeException + { + if (null == keyMaterialMap) { + loadKeyMaterialMap(); + } + return keyMaterialMap.get(keyIDInFile); + } + + private void loadKeyMaterialMap() + { + TrinoInputFile inputfile = trinoFileSystem.newInputFile(keyMaterialFile); + try (TrinoInputStream keyMaterialStream = inputfile.newStream()) { + JsonNode keyMaterialJson = objectMapper.readTree(keyMaterialStream); + keyMaterialMap = objectMapper.readValue(keyMaterialJson.traverse(), + new TypeReference>() {}); + } + catch (FileNotFoundException e) { + throw new ParquetCryptoRuntimeException("External key material not found at " + keyMaterialFile, e); + } + catch (IOException e) { + throw new ParquetCryptoRuntimeException("Failed to get key material from " + keyMaterialFile, e); + } + } +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoKeyToolkit.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoKeyToolkit.java new file mode 100644 index 000000000000..ac081e0516b7 --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoKeyToolkit.java @@ -0,0 +1,220 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import org.apache.parquet.crypto.AesGcmDecryptor; +import org.apache.parquet.crypto.AesMode; +import org.apache.parquet.crypto.ModuleCipherFactory; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; +import org.apache.parquet.crypto.TrinoCryptoConfigurationUtil; + +import java.util.Base64; +import java.util.concurrent.ConcurrentMap; + +public class TrinoKeyToolkit +{ + public static final long CACHE_LIFETIME_DEFAULT_SECONDS = 10 * 60; // 10 minutes + + // KMS client two level cache: token -> KMSInstanceId -> KmsClient + static final TwoLevelCacheWithExpiration KMS_CLIENT_CACHE_PER_TOKEN = + KmsClientCache.INSTANCE.getCache(); + + // KEK two level cache for unwrapping: token -> KEK_ID -> KEK bytes + static final TwoLevelCacheWithExpiration KEK_READ_CACHE_PER_TOKEN = + KEKReadCache.INSTANCE.getCache(); + + private TrinoKeyToolkit() + { + } + + private enum KmsClientCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + private enum KEKReadCache + { + INSTANCE; + private final TwoLevelCacheWithExpiration cache = + new TwoLevelCacheWithExpiration<>(); + + private TwoLevelCacheWithExpiration getCache() + { + return cache; + } + } + + static String formatTokenForLog(String accessToken) + { + int maxTokenDisplayLength = 5; + if (accessToken.length() <= maxTokenDisplayLength) { + return accessToken; + } + return accessToken.substring(accessToken.length() - maxTokenDisplayLength); + } + + static class KeyWithMasterID + { + private final byte[] keyBytes; + private final String masterID; + + KeyWithMasterID(byte[] keyBytes, String masterID) + { + this.keyBytes = keyBytes; + this.masterID = masterID; + } + + byte[] getDataKey() + { + return keyBytes; + } + + String getMasterID() + { + return masterID; + } + } + + static class KeyEncryptionKey + { + private final byte[] kekBytes; + private final byte[] kekID; + private String encodedKekID; + private final String encodedWrappedKEK; + + KeyEncryptionKey(byte[] kekBytes, byte[] kekID, String encodedWrappedKEK) + { + this.kekBytes = kekBytes; + this.kekID = kekID; + this.encodedWrappedKEK = encodedWrappedKEK; + } + + byte[] getBytes() + { + return kekBytes; + } + + byte[] getID() + { + return kekID; + } + + String getEncodedID() + { + if (null == encodedKekID) { + encodedKekID = Base64.getEncoder().encodeToString(kekID); + } + return encodedKekID; + } + + String getEncodedWrappedKEK() + { + return encodedWrappedKEK; + } + } + + /** + * Decrypts encrypted key with "masterKey", using AES-GCM and the "aad" + * + * @param encodedEncryptedKey base64 encoded encrypted key + * @param masterKeyBytes encryption key + * @param aad additional authenticated data + * @return decrypted key + */ + public static byte[] decryptKeyLocally(String encodedEncryptedKey, byte[] masterKeyBytes, byte[] aad) + { + byte[] encryptedKey = Base64.getDecoder().decode(encodedEncryptedKey); + + AesGcmDecryptor keyDecryptor; + + keyDecryptor = (AesGcmDecryptor) ModuleCipherFactory.getDecryptor(AesMode.GCM, masterKeyBytes); + + return keyDecryptor.decrypt(encryptedKey, 0, encryptedKey.length, aad); + } + + static TrinoKmsClient getKmsClient(String kmsInstanceID, String kmsInstanceURL, ParquetReaderOptions trinoParquetCryptoConfig, + String accessToken, long cacheEntryLifetime) + { + ConcurrentMap kmsClientPerKmsInstanceCache = + 
KMS_CLIENT_CACHE_PER_TOKEN.getOrCreateInternalCache(accessToken, cacheEntryLifetime); + + TrinoKmsClient kmsClient = + kmsClientPerKmsInstanceCache.computeIfAbsent(kmsInstanceID, + (k) -> createAndInitKmsClient(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken)); + + return kmsClient; + } + + private static TrinoKmsClient createAndInitKmsClient(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, + String kmsInstanceURL, String accessToken) + { + Class kmsClientClass = null; + TrinoKmsClient kmsClient = null; + + try { + kmsClientClass = TrinoCryptoConfigurationUtil.getClassFromConfig(trinoParquetCryptoConfig.getEncryptionKmsClientClass(), + TrinoKmsClient.class); + + if (null == kmsClientClass) { + throw new ParquetCryptoRuntimeException("Could not find class " + trinoParquetCryptoConfig.getEncryptionKmsClientClass()); + } + kmsClient = (TrinoKmsClient) kmsClientClass.newInstance(); + } + catch (InstantiationException | IllegalAccessException e) { + throw new ParquetCryptoRuntimeException("Could not instantiate KmsClient class: " + + kmsClientClass, e); + } + + kmsClient.initialize(trinoParquetCryptoConfig, kmsInstanceID, kmsInstanceURL, accessToken); + + return kmsClient; + } + + static class TrinoKmsClientAndDetails + { + public TrinoKmsClient getKmsClient() + { + return kmsClient; + } + + private TrinoKmsClient kmsClient; + private String kmsInstanceID; + private String kmsInstanceURL; + + public TrinoKmsClientAndDetails(TrinoKmsClient kmsClient, String kmsInstanceID, String kmsInstanceURL) + { + this.kmsClient = kmsClient; + this.kmsInstanceID = kmsInstanceID; + this.kmsInstanceURL = kmsInstanceURL; + } + + public String getKmsInstanceID() + { + return kmsInstanceID; + } + + public String getKmsInstanceURL() + { + return kmsInstanceURL; + } + } +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoKmsClient.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoKmsClient.java new file mode 100644 index 000000000000..35c16384c9db --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoKmsClient.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.parquet.crypto.keytools; + +import io.trino.parquet.ParquetReaderOptions; +import org.apache.parquet.crypto.KeyAccessDeniedException; + +public interface TrinoKmsClient +{ + void initialize(ParquetReaderOptions trinoParquetCryptoConfig, String kmsInstanceID, String kmsInstanceURL, String accessToken) + throws KeyAccessDeniedException; + + String wrapKey(byte[] keyBytes, String masterKeyIdentifier) + throws KeyAccessDeniedException; + + byte[] unwrapKey(String wrappedKey, String masterKeyIdentifier) + throws KeyAccessDeniedException; +} diff --git a/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java new file mode 100644 index 000000000000..4fcd19c0cc7a --- /dev/null +++ b/lib/trino-parquet/src/main/java/org/apache/parquet/crypto/keytools/TrinoPropertiesDrivenCryptoFactory.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.crypto.keytools; + +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.parquet.ParquetReaderOptions; +import org.apache.parquet.crypto.DecryptionKeyRetriever; +import org.apache.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; +import org.apache.parquet.crypto.TrinoDecryptionPropertiesFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TrinoPropertiesDrivenCryptoFactory + implements TrinoDecryptionPropertiesFactory +{ + private static final Logger LOG = LoggerFactory.getLogger(PropertiesDrivenCryptoFactory.class); + + @Override + public FileDecryptionProperties getFileDecryptionProperties(ParquetReaderOptions parquetReaderOptions, Location filePath, TrinoFileSystem trinoFileSystem) + throws ParquetCryptoRuntimeException + { + DecryptionKeyRetriever keyRetriever = new TrinoFileKeyUnwrapper(parquetReaderOptions, filePath, trinoFileSystem); + + if (LOG.isDebugEnabled()) { + LOG.debug("File decryption properties for {}", filePath); + } + + return FileDecryptionProperties.builder() + .withKeyRetriever(keyRetriever) + .withPlaintextFilesAllowed() + .build(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index 47dcbd151d25..aec9d2495d55 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -132,6 +132,7 @@ public static ParquetReader createParquetReader( return new RuntimeException(exception); }, Optional.empty(), + Optional.empty(), Optional.empty()); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java index 
7335d3ef19af..bef4b596e6f0 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java @@ -104,7 +104,7 @@ public int read() throws IOException { ColumnReader columnReader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false, Optional.empty(), null, null); columnReader.setPageReader(pageReader, Optional.empty()); int rowsRead = 0; while (rowsRead < dataPositions) { diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java index 22c57a04dc71..4abfadd02049 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java @@ -284,8 +284,8 @@ public String toString() private static ColumnReaderInput[] getColumnReaderInputs(ColumnReaderProvider columnReaderProvider) { Object[][] definitionLevelsProviders = Arrays.stream(DefinitionLevelsProvider.ofDefinitionLevel( - columnReaderProvider.getField().getDefinitionLevel(), - columnReaderProvider.getField().isRequired())) + columnReaderProvider.getField().getDefinitionLevel(), + columnReaderProvider.getField().isRequired())) .collect(toDataProvider()); PrimitiveField field = columnReaderProvider.getField(); @@ -563,7 +563,8 @@ else if (dictionaryEncoding == DictionaryEncoding.MIXED) { UNCOMPRESSED, inputPages.iterator(), dictionaryEncoding == DictionaryEncoding.ALL || (dictionaryEncoding == DictionaryEncoding.MIXED && testingPages.size() == 1), - false); + false, + Optional.empty(), null, null); } private static List createDataPages(List testingPages, ValuesWriter encoder, int maxDef, boolean required) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java index 5712562b7a72..0b382e4818f6 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java @@ -698,7 +698,7 @@ protected static PageReader getPageReaderMock(List dataPages, @Nullabl return ((DataPageV2) page).getDataEncoding(); }) .allMatch(encoding -> encoding == PLAIN_DICTIONARY || encoding == RLE_DICTIONARY), - hasNoNulls); + hasNoNulls, Optional.empty(), null, null); } private DataPage createDataPage(DataPageVersion version, ParquetEncoding encoding, ValuesWriter writer, int valueCount) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptDecryptUtil.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptDecryptUtil.java new file mode 100644 index 000000000000..de77d8d38949 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptDecryptUtil.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.parquet.reader; + +import org.apache.parquet.crypto.ColumnEncryptionProperties; +import org.apache.parquet.crypto.DecryptionKeyRetriever; +import org.apache.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.crypto.FileEncryptionProperties; +import org.apache.parquet.crypto.ParquetCipher; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class EncryptDecryptUtil +{ + private EncryptDecryptUtil() + { + } + + private static class DecryptionKeyRetrieverMock + implements DecryptionKeyRetriever + { + private final Map keyMap = new HashMap<>(); + + public DecryptionKeyRetrieverMock putKey(String keyId, byte[] keyBytes) + { + keyMap.put(keyId, keyBytes); + return this; + } + + @Override + public byte[] getKey(byte[] keyMetaData) + { + String keyId = new String(keyMetaData, StandardCharsets.UTF_8); + return keyMap.get(keyId); + } + } + + private static final String FOOTER_KEY_METADATA = "footkey"; + private static final String COL_KEY_METADATA = "col"; + private static final byte[] FOOTER_KEY = {0x01, 0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}; + private static final byte[] FOOTER_KEY_METADATA_BYTES = FOOTER_KEY_METADATA.getBytes(StandardCharsets.UTF_8); + private static final byte[] COL_KEY = {0x02, 0x03, 0x4, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11}; + private static final byte[] COL_KEY_METADATA_BYTES = COL_KEY_METADATA.getBytes(StandardCharsets.UTF_8); + + public static FileDecryptionProperties getFileDecryptionProperties() + { + DecryptionKeyRetrieverMock keyRetriever = new DecryptionKeyRetrieverMock(); + keyRetriever.putKey("footkey", FOOTER_KEY); + keyRetriever.putKey("col", COL_KEY); + keyRetriever.putKey("{\"kmsInstanceURL\":\"DEFAULT\",\"masterKeyID\":\"cd2694cc-41e0-4825-987a-0351d618abd5\",\"wrappedDEK\":\"rRip/ypzXE/RXq/tCfwEy4IstrjZWcisZyGVw9UtsCnSN5qsLdHAKD3fFoAjGqI9SRJxgy4KpU+MdPhU\",\"keyEncryptionKeyID\":\"KEJRXEF5xVuOd0SMSA8eYQ==\",\"doubleWrapping\":true,\"internalStorage\":true,\"isFooterKey\":true,\"keyMaterialType\":\"PKMT1\",\"kmsInstanceID\":\"DEFAULT\",\"wrappedKEK\":\"AQICAHgvfAPsKQ1WxnnLTOM5YtAZF92nPnqAsgAk0UooHpm/YwFMjl38sjS20MxvjXwO/qatAAAAfjB8BgkqhkiG9w0BBwagbzBtAgEAMGgGCSqGSIb3DQEHATAeBglghkgBZQMEAS4wEQQMTnsRbbkC1nljodRzAgEQgDvAFX5/JLGHSf37jizfbtstsSzKSSk0b2K69ET3wgMEjbxer7XPGUKZcBawWJ7pm6Y0pw5o6P6OCsunSw==\"}", FOOTER_KEY); + return FileDecryptionProperties.builder().withPlaintextFilesAllowed().withKeyRetriever(keyRetriever).build(); + } + + public static FileEncryptionProperties getFileEncryptionProperties(List encryptColumns, ParquetCipher cipher, Boolean encryptFooter) + { + if (encryptColumns.size() == 0) { + return null; + } + + Map columnPropertyMap = new HashMap<>(); + for (String encryptColumn : encryptColumns) { + ColumnPath columnPath = ColumnPath.fromDotString(encryptColumn); + ColumnEncryptionProperties columnEncryptionProperties = ColumnEncryptionProperties.builder(columnPath) 
+ .withKey(COL_KEY) + .withKeyMetaData(COL_KEY_METADATA_BYTES) + .build(); + columnPropertyMap.put(columnPath, columnEncryptionProperties); + } + + FileEncryptionProperties.Builder encryptionPropertiesBuilder = + FileEncryptionProperties.builder(FOOTER_KEY) + .withFooterKeyMetadata(FOOTER_KEY_METADATA_BYTES) + .withAlgorithm(cipher) + .withEncryptedColumns(columnPropertyMap); + + if (!encryptFooter) { + encryptionPropertiesBuilder.withPlaintextFooter(); + } + + return encryptionPropertiesBuilder.build(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java new file mode 100644 index 000000000000..d7b9f429e094 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFile.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.parquet.reader; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.example.data.simple.SimpleGroup; + +import java.io.IOException; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; + +public class EncryptionTestFile +{ + private final String fileName; + private final SimpleGroup[] fileContent; + + public EncryptionTestFile(String fileName, SimpleGroup[] fileContent) + { + checkArgument(!isNullOrEmpty(fileName), "file name cannot be null or empty"); + this.fileName = fileName; + checkArgument(fileContent != null && fileContent.length > 0, "file content cannot be null or empty"); + this.fileContent = fileContent; + } + + public String getFileName() + { + return this.fileName; + } + + public SimpleGroup[] getFileContent() + { + return fileContent; + } + + public long getFileSize() + throws IOException + { + Path path = new Path(fileName); + return path.getFileSystem(new Configuration()).getFileStatus(path).getLen(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFileBuilder.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFileBuilder.java new file mode 100644 index 000000000000..1ac34b5826c9 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/EncryptionTestFileBuilder.java @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.parquet.reader; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.crypto.FileEncryptionProperties; +import org.apache.parquet.crypto.ParquetCipher; +import org.apache.parquet.example.data.Group; +import org.apache.parquet.example.data.simple.SimpleGroup; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.example.ExampleParquetWriter; +import org.apache.parquet.hadoop.example.GroupWriteSupport; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; + +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; + +public class EncryptionTestFileBuilder +{ + private MessageType schema; + private Configuration conf; + private Map extraMeta = new HashMap<>(); + private int numRecord = 100000; + private ParquetProperties.WriterVersion writerVersion = ParquetProperties.WriterVersion.PARQUET_1_0; + private int pageSize = ParquetProperties.DEFAULT_PAGE_SIZE; + private String codec = "ZSTD"; + private String[] encryptColumns = {}; + private ParquetCipher cipher = ParquetCipher.AES_GCM_V1; + private Boolean footerEncryption = false; + + public EncryptionTestFileBuilder(Configuration conf, MessageType schema) + { + this.conf = conf; + this.schema = schema; + conf.set(GroupWriteSupport.PARQUET_EXAMPLE_SCHEMA, schema.toString()); + } + + public EncryptionTestFileBuilder withNumRecord(int numRecord) + { + this.numRecord = numRecord; + return this; + } + + public EncryptionTestFileBuilder withEncrytionAlgorithm(ParquetCipher cipher) + { + this.cipher = cipher; + return this; + } + + public EncryptionTestFileBuilder withExtraMeta(Map extraMeta) + { + this.extraMeta = extraMeta; + return this; + } + + public EncryptionTestFileBuilder withWriterVersion(ParquetProperties.WriterVersion writerVersion) + { + this.writerVersion = writerVersion; + return this; + } + + public EncryptionTestFileBuilder withPageSize(int pageSize) + { + this.pageSize = pageSize; + return this; + } + + public EncryptionTestFileBuilder withCodec(String codec) + { + this.codec = codec; + return this; + } + + public EncryptionTestFileBuilder withEncryptColumns(String[] encryptColumns) + { + this.encryptColumns = encryptColumns; + return this; + } + + public EncryptionTestFileBuilder withFooterEncryption() + { + this.footerEncryption = true; + return this; + } + + public EncryptionTestFile build() + throws IOException + { + String fileName = createTempFile("test"); + SimpleGroup[] fileContent = createFileContent(schema); + FileEncryptionProperties encryptionProperties = EncryptDecryptUtil.getFileEncryptionProperties(Arrays.asList(encryptColumns), cipher, footerEncryption); + ExampleParquetWriter.Builder builder = ExampleParquetWriter.builder(new Path(fileName)) + .withConf(conf) + .withWriterVersion(writerVersion) + .withExtraMetaData(extraMeta) + .withValidation(true) + .withPageSize(pageSize) + 
.withEncryption(encryptionProperties) + .withCompressionCodec(CompressionCodecName.valueOf(codec)); + try (ParquetWriter writer = builder.build()) { + for (int i = 0; i < fileContent.length; i++) { + writer.write(fileContent[i]); + } + } + return new EncryptionTestFile(fileName, fileContent); + } + + private SimpleGroup[] createFileContent(MessageType schema) + { + SimpleGroup[] simpleGroups = new SimpleGroup[numRecord]; + for (int i = 0; i < simpleGroups.length; i++) { + SimpleGroup g = new SimpleGroup(schema); + for (Type type : schema.getFields()) { + addValueToSimpleGroup(g, type); + } + simpleGroups[i] = g; + } + return simpleGroups; + } + + private void addValueToSimpleGroup(Group g, Type type) + { + if (type.isPrimitive()) { + PrimitiveType primitiveType = (PrimitiveType) type; + if (primitiveType.getPrimitiveTypeName().equals(INT32)) { + g.add(type.getName(), getInt()); + } + else if (primitiveType.getPrimitiveTypeName().equals(INT64)) { + g.add(type.getName(), getLong()); + } + else if (primitiveType.getPrimitiveTypeName().equals(BINARY)) { + g.add(type.getName(), getString()); + } + // Only support 3 types now, more can be added later + } + else { + GroupType groupType = (GroupType) type; + Group parentGroup = g.addGroup(groupType.getName()); + for (Type field : groupType.getFields()) { + addValueToSimpleGroup(parentGroup, field); + } + } + } + + private static long getInt() + { + return ThreadLocalRandom.current().nextInt(10000); + } + + private static long getLong() + { + return ThreadLocalRandom.current().nextLong(100000); + } + + private static String getString() + { + char[] chars = {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'x', 'z', 'y'}; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 100; i++) { + sb.append(chars[ThreadLocalRandom.current().nextInt(10)]); + } + return sb.toString(); + } + + public static String createTempFile(String prefix) + { + try { + return Files.createTempDirectory(prefix).toAbsolutePath().toString() + "/test.parquet"; + } + catch (IOException e) { + throw new AssertionError("Unable to create temporary file", e); + } + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java new file mode 100644 index 000000000000..36348955ce88 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockInputStreamTail.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.parquet.reader; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.hadoop.fs.FSDataInputStream; + +import java.io.IOException; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class MockInputStreamTail +{ + public static final int MAX_SUPPORTED_PADDING_BYTES = 64; + private static final int MAXIMUM_READ_LENGTH = Integer.MAX_VALUE - (MAX_SUPPORTED_PADDING_BYTES + 1); + + private final Slice tailSlice; + private final long fileSize; + + private MockInputStreamTail(long fileSize, Slice tailSlice) + { + this.tailSlice = requireNonNull(tailSlice, "tailSlice is null"); + this.fileSize = fileSize; + checkArgument(fileSize >= 0, "fileSize is negative: %s", fileSize); + checkArgument(tailSlice.length() <= fileSize, "length (%s) is greater than fileSize (%s)", tailSlice.length(), fileSize); + } + + public static MockInputStreamTail readTail(String path, long paddedFileSize, FSDataInputStream inputStream, int length) + throws IOException + { + checkArgument(length >= 0, "length is negative: %s", length); + checkArgument(length <= MAXIMUM_READ_LENGTH, "length (%s) exceeds maximum (%s)", length, MAXIMUM_READ_LENGTH); + long readSize = min(paddedFileSize, (length + MAX_SUPPORTED_PADDING_BYTES)); + long position = paddedFileSize - readSize; + // Actual read will be 1 byte larger to ensure we encounter an EOF where expected + byte[] buffer = new byte[toIntExact(readSize + 1)]; + int bytesRead = 0; + long startPos = inputStream.getPos(); + try { + inputStream.seek(position); + while (bytesRead < buffer.length) { + int n = inputStream.read(buffer, bytesRead, buffer.length - bytesRead); + if (n < 0) { + break; + } + bytesRead += n; + } + } + finally { + inputStream.seek(startPos); + } + if (bytesRead > readSize) { + throw rejectInvalidFileSize(path, paddedFileSize); + } + return new MockInputStreamTail(position + bytesRead, Slices.wrappedBuffer(buffer, max(0, bytesRead - length), min(bytesRead, length))); + } + + public static long readTailForFileSize(String path, long paddedFileSize, FSDataInputStream inputStream) + throws IOException + { + long position = max(paddedFileSize - MAX_SUPPORTED_PADDING_BYTES, 0); + long maxEOFAt = paddedFileSize + 1; + long startPos = inputStream.getPos(); + try { + inputStream.seek(position); + int c; + while (position < maxEOFAt) { + c = inputStream.read(); + if (c < 0) { + return position; + } + position++; + } + throw rejectInvalidFileSize(path, paddedFileSize); + } + finally { + inputStream.seek(startPos); + } + } + + private static IOException rejectInvalidFileSize(String path, long reportedSize) + throws IOException + { + throw new IOException(format("Incorrect file size (%s) for file (end of stream not reached): %s", reportedSize, path)); + } + + public long getFileSize() + { + return fileSize; + } + + public Slice getTailSlice() + { + return tailSlice; + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java new file mode 100644 index 000000000000..ca6c037abe7c --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/MockParquetDataSource.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); 
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.trino.parquet.reader;
+
+import io.trino.parquet.AbstractParquetDataSource;
+import io.trino.parquet.ParquetDataSourceId;
+import io.trino.parquet.ParquetReaderOptions;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
+import org.apache.parquet.internal.column.columnindex.ColumnIndex;
+import org.apache.parquet.internal.column.columnindex.OffsetIndex;
+
+import java.io.IOException;
+import java.util.Optional;
+
+public class MockParquetDataSource
+        extends AbstractParquetDataSource
+{
+    private final FSDataInputStream inputStream;
+    private long readTimeNanos;
+    private long readBytes;
+
+    public MockParquetDataSource(
+            ParquetDataSourceId id,
+            long estimatedSize,
+            FSDataInputStream inputStream)
+    {
+        super(id, estimatedSize, new ParquetReaderOptions());
+        this.inputStream = inputStream;
+    }
+
+    @Override
+    protected void readInternal(long position, byte[] buffer, int bufferOffset, int bufferLength)
+            throws IOException
+    {
+        readBytes += bufferLength;
+
+        long start = System.nanoTime();
+        try {
+            inputStream.readFully(position, buffer, bufferOffset, bufferLength);
+        }
+        catch (Exception e) {
+            throw new RuntimeException("Error reading from " + getId() + " at position " + position, e);
+        }
+        long currentReadTimeNanos = System.nanoTime() - start;
+
+        readTimeNanos += currentReadTimeNanos;
+    }
+
+    @Override
+    public void close()
+            throws IOException
+    {
+        inputStream.close();
+    }
+
+    public Optional<ColumnIndex> readColumnIndex(ColumnChunkMetaData column)
+            throws IOException
+    {
+        throw new IOException("Not supported");
+    }
+
+    public Optional<OffsetIndex> readOffsetIndex(ColumnChunkMetaData column)
+            throws IOException
+    {
+        throw new IOException("Not supported");
+    }
+}
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestEncryption.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestEncryption.java
new file mode 100644
index 000000000000..43005c5e1d09
--- /dev/null
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestEncryption.java
@@ -0,0 +1,396 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package io.trino.parquet.reader; + +import com.google.common.collect.ImmutableList; +import io.trino.parquet.Column; +import io.trino.parquet.Field; +import io.trino.parquet.GroupField; +import io.trino.parquet.ParquetCorruptionException; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.PrimitiveField; +import io.trino.spi.ErrorCode; +import io.trino.spi.ErrorType; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.crypto.InternalFileDecryptor; +import org.apache.parquet.crypto.ParquetCipher; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.FileMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.io.ColumnIO; +import org.apache.parquet.io.GroupColumnIO; +import org.apache.parquet.io.MessageColumnIO; +import org.apache.parquet.io.PrimitiveColumnIO; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.parquet.ParquetTypeUtils.getArrayElementColumn; +import static io.trino.parquet.ParquetTypeUtils.getColumnIO; +import static io.trino.parquet.ParquetTypeUtils.getMapKeyValueColumn; +import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.lang.String.format; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; +import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; +import static org.apache.parquet.schema.Type.Repetition.REQUIRED; +import static org.joda.time.DateTimeZone.UTC; +import static org.testng.Assert.assertEquals; + +public class TestEncryption +{ + private final Configuration conf = new Configuration(false); + + @Test + public void testBasicDecryption() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"name", "gender"}; + Map extraMetadata = new HashMap() + { + { + put("key1", "value1"); + put("key2", "value2"); + } + }; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(10000) + .withCodec("GZIP") + .withExtraMeta(extraMetadata) + .withPageSize(1000) + .withFooterEncryption() + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testAllColumnsDecryption() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"id", "name", "gender"}; + 
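+        // All three schema columns (id, name, gender) are encrypted and the footer is encrypted as well
+        // (withFooterEncryption below), so the reader must resolve a key for the footer and for every column chunk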
EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(10000) + .withCodec("GZIP") + .withPageSize(1000) + .withFooterEncryption() + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testNoColumnsDecryption() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {}; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(10000) + .withCodec("GZIP") + .withPageSize(1000) + .withFooterEncryption() + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testOneRecord() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"name", "gender"}; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(1) + .withCodec("GZIP") + .withPageSize(1000) + .withFooterEncryption() + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testMillionRows() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"name", "gender"}; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(1000000) + .withCodec("GZIP") + .withPageSize(1000) + .withFooterEncryption() + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testPlainTextFooter() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"name", "gender"}; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(10000) + .withCodec("SNAPPY") + .withPageSize(1000) + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testLargePageSize() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"name", "gender"}; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(100000) + .withCodec("GZIP") + .withPageSize(100000) + .withFooterEncryption() + .build(); + decryptAndValidate(inputFile); + } + + @Test + public void testAesGcmCtr() + throws IOException + { + MessageType schema = createSchema(); + String[] encryptColumns = {"name", "gender"}; + EncryptionTestFile inputFile = new EncryptionTestFileBuilder(conf, schema) + .withEncryptColumns(encryptColumns) + .withNumRecord(100000) + .withCodec("GZIP") + .withPageSize(1000) + .withEncrytionAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .build(); + decryptAndValidate(inputFile); + } + + private MessageType createSchema() + { + return new MessageType("schema", + new PrimitiveType(OPTIONAL, INT64, "id"), + new PrimitiveType(REQUIRED, BINARY, "name"), + new PrimitiveType(OPTIONAL, BINARY, "gender")); + } + + private void decryptAndValidate(EncryptionTestFile inputFile) + throws IOException + { + Path path = new Path(inputFile.getFileName()); + FileSystem fileSystem = path.getFileSystem(conf); + FSDataInputStream inputStream = fileSystem.open(path); + long fileSize = fileSystem.getFileStatus(path).getLen(); + Optional fileDecryptor = createFileDecryptor(); + ParquetDataSource dataSource = new MockParquetDataSource(new ParquetDataSourceId(path.toString()), fileSize, inputStream); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, +// inputFile.getFileSize(), + Optional.empty(), + 
fileDecryptor);
+        FileMetaData fileMetaData = parquetMetadata.getFileMetaData();
+        MessageType fileSchema = fileMetaData.getSchema();
+        MessageColumnIO messageColumn = getColumnIO(fileSchema, fileSchema);
+        ParquetReader parquetReader = createParquetReader(parquetMetadata, messageColumn, dataSource, fileDecryptor);
+        validateFile(parquetReader, messageColumn, inputFile);
+    }
+
+    private Optional<InternalFileDecryptor> createFileDecryptor()
+    {
+        FileDecryptionProperties fileDecryptionProperties = EncryptDecryptUtil.getFileDecryptionProperties();
+        if (fileDecryptionProperties != null) {
+            return Optional.of(new InternalFileDecryptor(fileDecryptionProperties));
+        }
+        return Optional.empty();
+    }
+
+    private ParquetReader createParquetReader(ParquetMetadata parquetMetadata,
+            MessageColumnIO messageColumn,
+            ParquetDataSource dataSource,
+            Optional<InternalFileDecryptor> fileDecryptor)
+    {
+        long nextStart = 0;
+        ImmutableList.Builder<RowGroupInfo> rowGroupInfoBuilder = ImmutableList.builder();
+        for (BlockMetaData block : parquetMetadata.getBlocks()) {
+            rowGroupInfoBuilder.add(new RowGroupInfo(block, nextStart, Optional.empty()));
+            nextStart += block.getRowCount();
+        }
+
+        List<Column> columns =
+                ImmutableList.of(
+                        constructColumn(BIGINT, lookupColumnByName(messageColumn, "id")).get(),
+                        constructColumn(VARCHAR, lookupColumnByName(messageColumn, "name")).get(),
+                        constructColumn(VARCHAR, lookupColumnByName(messageColumn, "gender")).get());
+
+        try {
+            return new ParquetReader(
+                    Optional.ofNullable(parquetMetadata.getFileMetaData().getCreatedBy()),
+                    columns,
+                    rowGroupInfoBuilder.build(),
+                    dataSource,
+                    UTC,
+                    newSimpleAggregatedMemoryContext(),
+                    new ParquetReaderOptions(),
+                    exception -> handleException(dataSource.getId(), exception),
+                    Optional.empty(),
+                    Optional.empty(),
+                    fileDecryptor);
+        }
+        catch (IOException ex) {
+            throw new IllegalStateException(ex);
+        }
+    }
+
+    public static TrinoException handleException(ParquetDataSourceId dataSourceId, Exception exception)
+    {
+        if (exception instanceof TrinoException) {
+            return (TrinoException) exception;
+        }
+        if (exception instanceof ParquetCorruptionException) {
+            return new TrinoException(() -> new ErrorCode(123, "wyu-code-1", ErrorType.INTERNAL_ERROR), exception);
+        }
+        return new TrinoException(() -> new ErrorCode(123, "wyu-code-2", ErrorType.INTERNAL_ERROR), format("Failed to read Parquet file: %s", dataSourceId), exception);
+    }
+
+    private void validateFile(ParquetReader parquetReader, MessageColumnIO messageColumn, EncryptionTestFile inputFile)
+            throws IOException
+    {
+        int rowIndex = 0;
+        io.trino.spi.Page page = parquetReader.nextPage();
+        while (page != null) {
+            validateColumn("id", BIGINT, rowIndex, parquetReader, messageColumn, inputFile);
+            validateColumn("name", VARCHAR, rowIndex, parquetReader, messageColumn, inputFile);
+            validateColumn("gender", VARCHAR, rowIndex, parquetReader, messageColumn, inputFile);
+            rowIndex += page.getPositionCount();
+            page = parquetReader.nextPage();
+        }
+    }
+
+    private void validateColumn(String name, Type type, int rowIndex, ParquetReader parquetReader, MessageColumnIO messageColumn, EncryptionTestFile inputFile)
+            throws IOException
+    {
+        Block block = parquetReader.readBlock(constructField(type, lookupColumnByName(messageColumn, name)).orElse(null));
+        for (int i = 0; i < block.getPositionCount(); i++) {
+            if
(type.equals(BIGINT)) { + assertEquals(inputFile.getFileContent()[rowIndex++].getLong(name, 0), block.getLong(i, 0)); + } + else if (type.equals(INT32)) { + assertEquals(inputFile.getFileContent()[rowIndex++].getInteger(name, 0), block.getInt(i, 0)); + } + else if (type.equals(VARCHAR)) { + assertEquals(inputFile.getFileContent()[rowIndex++].getString(name, 0), block.getSlice(i, 0, block.getSliceLength(i)).toStringUtf8()); + } + } + } + + private Optional constructColumn(Type type, ColumnIO columnIO) + { + Optional field = constructField(type, columnIO); + if (field.isEmpty()) { + return Optional.empty(); + } + return Optional.of(new Column(columnIO.getName(), field.get())); + } + + private Optional constructField(Type type, ColumnIO columnIO) + { + if (columnIO == null) { + return Optional.empty(); + } + boolean required = columnIO.getType().getRepetition() != OPTIONAL; + int repetitionLevel = columnIO.getRepetitionLevel(); + int definitionLevel = columnIO.getDefinitionLevel(); + if (type instanceof RowType) { + RowType rowType = (RowType) type; + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + ImmutableList.Builder> fieldsBuilder = ImmutableList.builder(); + List fields = rowType.getFields(); + boolean structHasParameters = false; + for (int i = 0; i < fields.size(); i++) { + RowType.Field rowField = fields.get(i); + String name = rowField.getName().get().toLowerCase(Locale.ENGLISH); + Optional field = constructField(rowField.getType(), lookupColumnByName(groupColumnIO, name)); + structHasParameters |= field.isPresent(); + fieldsBuilder.add(field); + } + if (structHasParameters) { + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, fieldsBuilder.build())); + } + return Optional.empty(); + } + if (type instanceof MapType) { + MapType mapType = (MapType) type; + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO); + if (keyValueColumnIO.getChildrenCount() != 2) { + return Optional.empty(); + } + Optional keyField = constructField(mapType.getKeyType(), keyValueColumnIO.getChild(0)); + Optional valueField = constructField(mapType.getValueType(), keyValueColumnIO.getChild(1)); + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField))); + } + if (type instanceof ArrayType) { + ArrayType arrayType = (ArrayType) type; + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + if (groupColumnIO.getChildrenCount() != 1) { + return Optional.empty(); + } + Optional field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0))); + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field))); + } + PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO; + ColumnDescriptor column = new ColumnDescriptor(primitiveColumnIO.getColumnDescriptor().getPath(), columnIO.getType().asPrimitiveType(), repetitionLevel, definitionLevel); + return Optional.of(new PrimitiveField(type, required, column, primitiveColumnIO.getId())); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java new file mode 100644 index 000000000000..d899d0d4f670 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestHiddenColumnChunkMetaData.java @@ -0,0 +1,66 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import org.apache.parquet.column.Encoding; +import org.apache.parquet.crypto.HiddenColumnChunkMetaData; +import org.apache.parquet.crypto.HiddenColumnException; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.Set; + +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestHiddenColumnChunkMetaData +{ + @Test + public void testIsHiddenColumn() + { + ColumnChunkMetaData column = new HiddenColumnChunkMetaData(ColumnPath.fromDotString("a.b.c"), + "hdfs:/foo/bar/a.parquet"); + assertTrue(HiddenColumnChunkMetaData.isHiddenColumn(column)); + } + + @Test + public void testIsNotHiddenColumn() + { + Set encodingSet = Collections.singleton(Encoding.RLE); + ColumnChunkMetaData column = ColumnChunkMetaData.get(ColumnPath.fromDotString("a.b.c"), BINARY, + CompressionCodecName.GZIP, encodingSet, -1, -1, -1, -1, -1); + assertFalse(HiddenColumnChunkMetaData.isHiddenColumn(column)); + } + + @Test(expectedExceptions = HiddenColumnException.class) + public void testHiddenColumnException() + { + ColumnChunkMetaData column = new HiddenColumnChunkMetaData(ColumnPath.fromDotString("a.b.c"), + "hdfs:/foo/bar/a.parquet"); + column.getStatistics(); + } + + @Test + public void testNoHiddenColumnException() + { + Set encodingSet = Collections.singleton(Encoding.RLE); + ColumnChunkMetaData column = ColumnChunkMetaData.get(ColumnPath.fromDotString("a.b.c"), BINARY, + CompressionCodecName.GZIP, encodingSet, -1, -1, -1, -1, -1); + column.getStatistics(); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java index 2e2b9fe08d74..415cb31daa9e 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java @@ -109,7 +109,7 @@ public void testVariousTimestamps(TimestampType type, BiFunction { diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java index 1d27c37c91eb..55d51aeb20f9 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java @@ -60,7 +60,7 @@ private void testTimeMillsInt32(TimeType timeType) ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("time_millis_int32.snappy.parquet").toURI()), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, 
Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); Page page = reader.nextPage(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java index bae20cf49539..446e057205b2 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java @@ -30,7 +30,7 @@ import org.apache.parquet.column.values.ValuesWriter; import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridValuesWriter; import org.apache.parquet.schema.PrimitiveType; -import org.junit.jupiter.api.Test; +import org.testng.annotations.Test; import java.io.IOException; import java.util.List; @@ -137,7 +137,7 @@ private static PageReader getSimplePageReaderMock(ParquetEncoding encoding) encoding, encoding, PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, Optional.empty(), null, null); } private static PageReader getNullOnlyPageReaderMock() @@ -154,6 +154,6 @@ private static PageReader getNullOnlyPageReaderMock() RLE, RLE, PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, Optional.empty(), null, null); } } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java index 555bc6275e60..f106a813ffbe 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java @@ -102,7 +102,7 @@ public void testWrittenPageSize() columnNames, generateInputPages(types, 100, 1000)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isEqualTo(1); assertThat(parquetMetadata.getBlocks().get(0).getRowCount()).isEqualTo(100 * 1000); @@ -116,7 +116,9 @@ public void testWrittenPageSize() chunkMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, - Optional.empty()); + Optional.empty(), + Optional.empty(), + -1); pageReader.readDictionaryPage(); assertThat(pageReader.hasNext()).isTrue(); @@ -164,7 +166,7 @@ public void testLargeStringTruncation() ImmutableList.of(new Page(2, blockA, blockB))), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); BlockMetaData blockMetaData = getOnlyElement(parquetMetadata.getBlocks()); ColumnChunkMetaData chunkMetaData = blockMetaData.getColumns().get(0); @@ -197,7 +199,7 @@ public void testColumnReordering() generateInputPages(types, 100, 100)), 
new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(10); for (BlockMetaData blockMetaData : parquetMetadata.getBlocks()) { // Verify that the columns are stored in the same order as the metadata @@ -253,7 +255,7 @@ public void testDictionaryPageOffset() generateInputPages(types, 100, 100)), new ParquetReaderOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(1); for (BlockMetaData blockMetaData : parquetMetadata.getBlocks()) { ColumnChunkMetaData chunkMetaData = getOnlyElement(blockMetaData.getColumns()); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index e374d68f8bea..1da9f5c5bbe3 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -512,7 +512,8 @@ private ReaderPageSource createParquetPageSource(Location path) new ParquetReaderOptions().withBloomFilter(false), Optional.empty(), domainCompactionThreshold, - OptionalLong.of(fileSize)); + OptionalLong.of(fileSize), + fileSystem); } @Override diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index 4a56b213cd53..2352aa96c7df 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -23,6 +23,7 @@ import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoFileSystemFactory; import io.trino.filesystem.TrinoInputFile; +import io.trino.parquet.EncryptionUtils; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.reader.MetadataReader; @@ -57,6 +58,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.TypeManager; +import org.apache.parquet.crypto.InternalFileDecryptor; import org.apache.parquet.hadoop.metadata.FileMetaData; import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.schema.MessageType; @@ -206,7 +208,7 @@ public ConnectorPageSource createPageSource( .withSmallFileThreshold(getParquetSmallFileThreshold(session)) .withUseColumnIndex(isParquetUseColumnIndex(session)); - Map parquetFieldIdToName = columnMappingMode == ColumnMappingMode.ID ? loadParquetIdAndNameMapping(inputFile, options) : ImmutableMap.of(); + Map parquetFieldIdToName = columnMappingMode == ColumnMappingMode.ID ? 
loadParquetIdAndNameMapping(inputFile, options, fileSystem) : ImmutableMap.of(); ImmutableSet.Builder missingColumnNames = ImmutableSet.builder(); ImmutableList.Builder hiveColumnHandles = ImmutableList.builder(); @@ -237,7 +239,8 @@ public ConnectorPageSource createPageSource( options, Optional.empty(), domainCompactionThreshold, - OptionalLong.of(split.getFileSize())); + OptionalLong.of(split.getFileSize()), + fileSystem); Optional projectionsAdapter = pageSource.getReaderColumns().map(readerColumns -> new ReaderProjectionsAdapter( @@ -286,10 +289,11 @@ private PositionDeleteFilter readDeletes( } } - public Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options) + public Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options, TrinoFileSystem trinoFileSystem) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + final Optional fileDecryptor = EncryptionUtils.createDecryptor(options, inputFile.location(), trinoFileSystem); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), fileDecryptor); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java index 7c27b8d151cd..fccea8484f0c 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java @@ -197,7 +197,7 @@ public DataFileInfo getDataFileInfo() private static DeltaLakeJsonFileStatistics readStatistics(FileMetaData fileMetaData, Location path, Map typeForColumn, long rowCount) throws IOException { - ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString())); + ParquetMetadata parquetMetadata = MetadataReader.createParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString()), Optional.empty(), false); ImmutableMultimap.Builder metadataForColumn = ImmutableMultimap.builder(); for (BlockMetaData blockMetaData : parquetMetadata.getBlocks()) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java index d5cf1f482538..82b79b2480ce 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java @@ -203,7 +203,8 @@ private static DeltaLakePageSource createDeltaLakePageSource( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.empty()); + OptionalLong.empty(), + fileSystem); verify(pageSource.getReaderColumns().isEmpty(), "Unexpected reader columns: %s", pageSource.getReaderColumns().orElse(null)); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java index 
bdf2f428dc0f..de52337fad9f 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java @@ -213,7 +213,8 @@ public Stream getCheckpointTransactionLogEntries( checkpoint, checkpointFile, partitionConstraint, - addStatsMinMaxColumnFilter))); + addStatsMinMaxColumnFilter, + fileSystem))); } return resultStream; } @@ -234,7 +235,8 @@ private Iterator getCheckpointTransactionLogEntrie LastCheckpoint checkpoint, TrinoInputFile checkpointFile, TupleDomain partitionConstraint, - Optional> addStatsMinMaxColumnFilter) + Optional> addStatsMinMaxColumnFilter, + TrinoFileSystem fileSystem) throws IOException { long fileSize; @@ -258,7 +260,8 @@ private Iterator getCheckpointTransactionLogEntrie checkpointRowStatisticsWritingEnabled, domainCompactionThreshold, partitionConstraint, - addStatsMinMaxColumnFilter); + addStatsMinMaxColumnFilter, + fileSystem); } public record MetadataAndProtocolEntry(MetadataEntry metadataEntry, ProtocolEntry protocolEntry) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 9b90eaf50930..df8201576111 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.math.LongMath; import io.airlift.log.Logger; +import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoInputFile; import io.trino.parquet.Column; import io.trino.parquet.ParquetReaderOptions; @@ -161,7 +162,8 @@ public CheckpointEntryIterator( boolean checkpointRowStatisticsWritingEnabled, int domainCompactionThreshold, TupleDomain partitionConstraint, - Optional> addStatsMinMaxColumnFilter) + Optional> addStatsMinMaxColumnFilter, + TrinoFileSystem trinoFileSystem) { this.checkpointPath = checkpoint.location().toString(); this.session = requireNonNull(session, "session is null"); @@ -214,7 +216,8 @@ public CheckpointEntryIterator( parquetReaderOptions, Optional.empty(), domainCompactionThreshold, - OptionalLong.empty()); + OptionalLong.empty(), + trinoFileSystem); verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns"); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java index 2d74baea49f7..1af384912480 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java @@ -310,7 +310,7 @@ private void testOptimizeWithColumnMappingMode(String columnMappingMode) TrinoInputFile inputFile = new LocalInputFile(tableLocation.resolve(addFileEntry.getPath()).toFile()); ParquetMetadata parquetMetadata = MetadataReader.readFooter( new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()), - Optional.empty()); + Optional.empty(), Optional.empty()); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); PrimitiveType physicalType = 
getOnlyElement(fileMetaData.getSchema().getColumns().iterator()).getPrimitiveType(); assertThat(physicalType.getName()).isEqualTo(physicalName); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java index dcf83f52d74a..fe79d45a1688 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java @@ -1028,7 +1028,7 @@ private CheckpointEntryIterator createCheckpointEntryIterator( true, new DeltaLakeConfig().getDomainCompactionThreshold(), partitionConstraint, - addStatsMinMaxColumnFilter); + addStatsMinMaxColumnFilter, fileSystem); } private static TrinoOutputFile createOutputFile(String path) diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java index b928ae46be54..3e9195621efb 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java @@ -486,7 +486,7 @@ private CheckpointEntries readCheckpoint(String checkpointPath, MetadataEntry me rowStatisticsEnabled, new DeltaLakeConfig().getDomainCompactionThreshold(), TupleDomain.all(), - Optional.of(alwaysTrue())); + Optional.of(alwaysTrue()), fileSystem); CheckpointBuilder checkpointBuilder = new CheckpointBuilder(); while (checkpointEntryIterator.hasNext()) { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/statistics/TestDeltaLakeFileStatistics.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/statistics/TestDeltaLakeFileStatistics.java index 8159b9b9fce2..8ca2f1c4f2b0 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/statistics/TestDeltaLakeFileStatistics.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/statistics/TestDeltaLakeFileStatistics.java @@ -108,7 +108,8 @@ public void testParseParquetStatistics() true, new DeltaLakeConfig().getDomainCompactionThreshold(), TupleDomain.all(), - Optional.empty()); + Optional.empty(), + null); MetadataEntry metadataEntry = getOnlyElement(metadataEntryIterator).getMetaData(); CheckpointEntryIterator protocolEntryIterator = new CheckpointEntryIterator( checkpointFile, @@ -124,7 +125,8 @@ public void testParseParquetStatistics() true, new DeltaLakeConfig().getDomainCompactionThreshold(), TupleDomain.all(), - Optional.empty()); + Optional.empty(), + null); ProtocolEntry protocolEntry = getOnlyElement(protocolEntryIterator).getProtocol(); CheckpointEntryIterator checkpointEntryIterator = new CheckpointEntryIterator( @@ -141,7 +143,8 @@ public void testParseParquetStatistics() true, new DeltaLakeConfig().getDomainCompactionThreshold(), TupleDomain.all(), - Optional.of(alwaysTrue())); + Optional.of(alwaysTrue()), + null); DeltaLakeTransactionLogEntry matchingAddFileEntry = null; while (checkpointEntryIterator.hasNext()) { DeltaLakeTransactionLogEntry entry = 
checkpointEntryIterator.next(); diff --git a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index 76f7d9f3fbdd..eda6a955fd5e 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -264,6 +264,12 @@ provided + + io.trino.hadoop + hadoop-apache + provided + + org.jetbrains annotations @@ -434,12 +440,6 @@ test - - io.trino.hadoop - hadoop-apache - test - - io.trino.hive hive-apache diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index 375e615fd5d6..7519ed9b540a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -23,6 +23,7 @@ import io.trino.filesystem.TrinoInputFile; import io.trino.memory.context.AggregatedMemoryContext; import io.trino.parquet.Column; +import io.trino.parquet.EncryptionUtils; import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; @@ -49,6 +50,7 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.crypto.InternalFileDecryptor; import org.apache.parquet.hadoop.metadata.FileMetaData; import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.io.MessageColumnIO; @@ -189,7 +191,8 @@ public Optional createPageSource( .withBloomFilter(useParquetBloomFilter(session)), Optional.empty(), domainCompactionThreshold, - OptionalLong.of(estimatedFileSize))); + OptionalLong.of(estimatedFileSize), + fileSystem)); } /** @@ -207,7 +210,8 @@ public static ReaderPageSource createPageSource( ParquetReaderOptions options, Optional parquetWriteValidation, int domainCompactionThreshold, - OptionalLong estimatedFileSize) + OptionalLong estimatedFileSize, + TrinoFileSystem trinoFileSystem) { MessageType fileSchema; MessageType requestedSchema; @@ -216,8 +220,9 @@ public static ReaderPageSource createPageSource( try { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); dataSource = createDataSource(inputFile, estimatedFileSize, options, memoryContext, stats); + final Optional fileDecryptor = EncryptionUtils.createDecryptor(options, inputFile.location(), trinoFileSystem); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, parquetWriteValidation); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, parquetWriteValidation, fileDecryptor); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); fileSchema = fileMetaData.getSchema(); @@ -278,7 +283,8 @@ public static ReaderPageSource createPageSource( // We avoid using disjuncts of parquetPredicate for page pruning in ParquetReader as currently column indexes // are not present in the Parquet files which are read with disjunct predicates. parquetPredicates.size() == 1 ? 
Optional.of(parquetPredicates.get(0)) : Optional.empty(), - parquetWriteValidation); + parquetWriteValidation, + fileDecryptor); ConnectorPageSource parquetPageSource = createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); return new ReaderPageSource(parquetPageSource, readerProjections); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java index b4b1841f6e8e..a43992bba73b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetReaderConfig.java @@ -134,6 +134,87 @@ public boolean isUseBloomFilter() return options.useBloomFilter(); } + // TODO (wyu): The following properties are translated directly into Parquet modular encryption properties and passed to the Parquet library. + // Another way to pass them is through hive.config.resources, which could be better because 1. we would not need to translate every property + // and 2. the Parquet library expects a hadoop.conf.Configuration instance as an argument, which is also what HdfsConfigurationInitializer produces. + @Config("parquet.crypto-factory-class") + @ConfigDescription("Crypto factory class used to encrypt or decrypt Parquet files") + public ParquetReaderConfig setCryptoFactoryClass(String cryptoFactoryClass) + { + options = options.withCryptoFactoryClass(cryptoFactoryClass); + return this; + } + + public String getCryptoFactoryClass() + { + return options.getCryptoFactoryClass(); + } + + @Config("parquet.encryption-kms-client-class") + @ConfigDescription("Class implementing the KmsClient interface (KMS stands for key management service)") + public ParquetReaderConfig setEncryptionKmsClientClass(String encryptionKmsClientClass) + { + options = options.withEncryptionKmsClientClass(encryptionKmsClientClass); + return this; + } + + public String getEncryptionKmsClientClass() + { + return options.getEncryptionKmsClientClass(); + } + + @Config("parquet.encryption-kms-instance-id") + @ConfigDescription("ID of the KMS instance that manages the encryption keys") + public ParquetReaderConfig setEncryptionKmsInstanceId(String encryptionKmsInstanceId) + { + options = options.withEncryptionKmsInstanceId(encryptionKmsInstanceId); + return this; + } + + public String getEncryptionKmsInstanceId() + { + return options.getEncryptionKmsInstanceId(); + } + + @Config("parquet.encryption-kms-instance-url") + @ConfigDescription("URL of the KMS instance that manages the encryption keys") + public ParquetReaderConfig setEncryptionKmsInstanceUrl(String encryptionKmsInstanceUrl) + { + options = options.withEncryptionKmsInstanceUrl(encryptionKmsInstanceUrl); + return this; + } + + public String getEncryptionKmsInstanceUrl() + { + return options.getEncryptionKmsInstanceUrl(); + } + + @Config("parquet.encryption-key-access-token") + @ConfigDescription("Authorization token passed to the KMS for key access") + public ParquetReaderConfig setEncryptionKeyAccessToken(String encryptionKeyAccessToken) + { + options = options.withEncryptionKeyAccessToken(encryptionKeyAccessToken); + return this; + } + + public String getEncryptionKeyAccessToken() + { + return options.getEncryptionKeyAccessToken(); + } + + @Config("parquet.encryption-cache-lifetime-seconds") + @ConfigDescription("Lifetime in seconds of cached key material and KMS clients") + public ParquetReaderConfig setEncryptionCacheLifetimeSeconds(Long encryptionCacheLifetimeSeconds) + { + options = options.withEncryptionCacheLifetimeSeconds(encryptionCacheLifetimeSeconds); + return this; + } + + public Long getEncryptionCacheLifetimeSeconds() + { + return
options.getEncryptionCacheLifetimeSeconds(); + } + @Config("parquet.small-file-threshold") @ConfigDescription("Size below which a parquet file will be read entirely") public ParquetReaderConfig setSmallFileThreshold(DataSize smallFileThreshold) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java index 4c168c2b0247..a886b0aeb888 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java @@ -308,7 +308,7 @@ private static BloomFilterStore generateBloomFilterStore(ParquetTester.TempFile TrinoInputFile inputFile = new LocalInputFile(tempFile.getFile()); TrinoParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); ColumnChunkMetaData columnChunkMetaData = getOnlyElement(getOnlyElement(parquetMetadata.getBlocks()).getColumns()); return new BloomFilterStore(dataSource, getOnlyElement(parquetMetadata.getBlocks()), Set.of(columnChunkMetaData.getPath())); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java index bbbf23becea5..a9f83add4a8c 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetReaderConfig.java @@ -53,6 +53,12 @@ public void testExplicitPropertyMappings() .put("parquet.use-column-index", "false") .put("parquet.use-bloom-filter", "false") .put("parquet.small-file-threshold", "1kB") + .put("parquet.crypto-factory-class", "factory.class") + .put("parquet.encryption-kms-client-class", "kms.client.class") + .put("parquet.encryption-key-access-token", "default_token") + .put("parquet.encryption-kms-instance-id", "kms_id") + .put("parquet.encryption-kms-instance-url", "kms_url") + .put("parquet.encryption-cache-lifetime-seconds", "3600") .buildOrThrow(); ParquetReaderConfig expected = new ParquetReaderConfig() @@ -63,7 +69,13 @@ public void testExplicitPropertyMappings() .setMaxMergeDistance(DataSize.of(342, KILOBYTE)) .setUseColumnIndex(false) .setUseBloomFilter(false) - .setSmallFileThreshold(DataSize.of(1, KILOBYTE)); + .setSmallFileThreshold(DataSize.of(1, KILOBYTE)) + .setCryptoFactoryClass("factory.class") + .setEncryptionKmsClientClass("kms.client.class") + .setEncryptionKeyAccessToken("default_token") + .setEncryptionKmsInstanceId("kms_id") + .setEncryptionKmsInstanceUrl("kms_url") + .setEncryptionCacheLifetimeSeconds(3600L); assertFullMapping(properties, expected); } diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java index 080c51746d49..698578f76378 100644 --- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java +++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java @@ -20,6 +20,7 @@ import io.trino.filesystem.TrinoFileSystemFactory; import io.trino.filesystem.TrinoInputFile; import 
io.trino.memory.context.AggregatedMemoryContext; +import io.trino.parquet.EncryptionUtils; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; @@ -48,6 +49,7 @@ import io.trino.spi.type.Decimals; import io.trino.spi.type.TypeSignature; import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.crypto.InternalFileDecryptor; import org.apache.parquet.hadoop.metadata.FileMetaData; import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.io.MessageColumnIO; @@ -167,7 +169,8 @@ public ConnectorPageSource createPageSource( inputFile, dataSourceStats, options.withSmallFileThreshold(getParquetSmallFileThreshold(session)), - timeZone); + timeZone, + fileSystem); return new HudiPageSource( toPartitionName(split.getPartitionKeys()), @@ -186,7 +189,8 @@ private static ConnectorPageSource createPageSource( TrinoInputFile inputFile, FileFormatDataSourceStats dataSourceStats, ParquetReaderOptions options, - DateTimeZone timeZone) + DateTimeZone timeZone, + TrinoFileSystem trinoFileSystem) { ParquetDataSource dataSource = null; boolean useColumnNames = shouldUseParquetColumnNames(session); @@ -196,7 +200,9 @@ private static ConnectorPageSource createPageSource( try { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); dataSource = createDataSource(inputFile, OptionalLong.of(hudiSplit.getFileSize()), options, memoryContext, dataSourceStats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + final Optional fileDecryptor = EncryptionUtils.createDecryptor(options, inputFile.location(), trinoFileSystem); + + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), fileDecryptor); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); @@ -242,7 +248,8 @@ private static ConnectorPageSource createPageSource( options, exception -> handleException(dataSourceId, exception), Optional.of(parquetPredicate), - Optional.empty()); + Optional.empty(), + fileDecryptor); return createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); } catch (IOException | RuntimeException e) { diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 19bde2a57078..88382c50cd84 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -39,6 +39,7 @@ import io.trino.orc.TupleDomainOrcPredicate; import io.trino.orc.TupleDomainOrcPredicate.TupleDomainOrcPredicateBuilder; import io.trino.parquet.Column; +import io.trino.parquet.EncryptionUtils; import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; @@ -109,6 +110,7 @@ import org.apache.iceberg.util.StructLikeSet; import org.apache.iceberg.util.StructProjection; import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.crypto.InternalFileDecryptor; import org.apache.parquet.hadoop.metadata.FileMetaData; import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.io.ColumnIO; @@ -372,7 +374,8 @@ else if (identity.getId() == TRINO_MERGE_PARTITION_DATA) { 
requiredColumns, effectivePredicate, nameMapping, - partitionKeys); + partitionKeys, + fileSystem); ReaderPageSource dataPageSource = readerPageSourceWithRowPositions.getReaderPageSource(); Optional projectionsAdapter = dataPageSource.getReaderColumns().map(readerColumns -> @@ -528,7 +531,8 @@ private ConnectorPageSource openDeletes( columns, tupleDomain, Optional.empty(), - ImmutableMap.of()) + ImmutableMap.of(), + fileSystem) .getReaderPageSource() .get(); } @@ -546,7 +550,8 @@ public ReaderPageSourceWithRowPositions createDataPageSource( List dataColumns, TupleDomain predicate, Optional nameMapping, - Map> partitionKeys) + Map> partitionKeys, + TrinoFileSystem fileSystem) { switch (fileFormat) { case ORC: @@ -590,7 +595,8 @@ public ReaderPageSourceWithRowPositions createDataPageSource( predicate, fileFormatDataSourceStats, nameMapping, - partitionKeys); + partitionKeys, + fileSystem); case AVRO: return createAvroPageSource( inputFile, @@ -928,14 +934,16 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( TupleDomain effectivePredicate, FileFormatDataSourceStats fileFormatDataSourceStats, Optional nameMapping, - Map> partitionKeys) + Map> partitionKeys, + TrinoFileSystem trinoFileSystem) { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); ParquetDataSource dataSource = null; try { dataSource = createDataSource(inputFile, OptionalLong.of(fileSize), options, memoryContext, fileFormatDataSourceStats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); + Optional fileDecryptor = EncryptionUtils.createDecryptor(options, inputFile.location(), trinoFileSystem); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty(), fileDecryptor); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); if (nameMapping.isPresent() && !ParquetSchemaUtil.hasIds(fileSchema)) { @@ -1046,7 +1054,8 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { options, exception -> handleException(dataSourceId, exception), Optional.empty(), - Optional.empty()); + Optional.empty(), + fileDecryptor); return new ReaderPageSourceWithRowPositions( new ReaderPageSource( pageSourceBuilder.build(parquetReader), diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java index 91283376cbb6..34ad81ea1d55 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrateProcedure.java @@ -404,7 +404,7 @@ private static Metrics loadMetrics(TrinoInputFile file, HiveStorageFormat storag private static Metrics parquetMetrics(TrinoInputFile file, MetricsConfig metricsConfig, NameMapping nameMapping) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(file, new ParquetReaderOptions(), new FileFormatDataSourceStats())) { - ParquetMetadata metadata = MetadataReader.readFooter(dataSource, Optional.empty()); + ParquetMetadata metadata = MetadataReader.readFooter(dataSource, Optional.empty(), Optional.empty()); return ParquetUtil.footerMetrics(metadata, Stream.empty(), metricsConfig, nameMapping); } catch (IOException e) { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java 
b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java index a4716d6b5701..d779ca3a5117 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java @@ -118,7 +118,7 @@ public static boolean checkParquetFileSorting(TrinoInputFile inputFile, String s try { parquetMetadata = MetadataReader.readFooter( new TrinoParquetDataSource(inputFile, new ParquetReaderOptions(), new FileFormatDataSourceStats()), - Optional.empty()); + Optional.empty(), Optional.empty()); } catch (IOException e) { throw new UncheckedIOException(e);
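Note: the following is an illustrative sketch, not part of the change set above. It shows how a connector is expected to combine the new EncryptionUtils hook with the three-argument MetadataReader.readFooter when reading a (potentially encrypted) Parquet footer; the helper class and method names are hypothetical, and the decryptor is only present when parquet.crypto-factory-class and the related KMS properties are configured. The same pattern appears in the Delta Lake, Hive, Hudi, and Iceberg page source providers in the diffs above.

import io.trino.filesystem.Location;
import io.trino.filesystem.TrinoFileSystem;
import io.trino.filesystem.TrinoInputFile;
import io.trino.parquet.EncryptionUtils;
import io.trino.parquet.ParquetDataSource;
import io.trino.parquet.ParquetReaderOptions;
import io.trino.parquet.reader.MetadataReader;
import io.trino.plugin.hive.FileFormatDataSourceStats;
import io.trino.plugin.hive.parquet.TrinoParquetDataSource;
import org.apache.parquet.crypto.InternalFileDecryptor;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;

import java.io.IOException;
import java.util.Optional;

// Hypothetical helper, shown only to illustrate the new decryption wiring.
final class EncryptedParquetFooterExample
{
    private EncryptedParquetFooterExample() {}

    static ParquetMetadata readFooter(TrinoFileSystem fileSystem, Location path, ParquetReaderOptions options)
            throws IOException
    {
        // Empty unless the parquet.crypto-factory-class (and KMS) properties are set on ParquetReaderOptions.
        Optional<InternalFileDecryptor> fileDecryptor = EncryptionUtils.createDecryptor(options, path, fileSystem);

        TrinoInputFile inputFile = fileSystem.newInputFile(path);
        try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, new FileFormatDataSourceStats())) {
            // When a decryptor is present, readFooter uses it to decrypt the footer and column metadata.
            return MetadataReader.readFooter(dataSource, Optional.empty(), fileDecryptor);
        }
    }
}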