diff --git a/docs/src/main/sphinx/connector/hive.md b/docs/src/main/sphinx/connector/hive.md index b4606eba756c..76192dc548e9 100644 --- a/docs/src/main/sphinx/connector/hive.md +++ b/docs/src/main/sphinx/connector/hive.md @@ -417,6 +417,43 @@ limitations and differences: - `GRANT privilege ON SCHEMA schema` is not supported. Schema ownership can be changed with `ALTER SCHEMA schema SET AUTHORIZATION user` +(hive-parquet-encryption)= +## Parquet encryption + +The Hive connector supports reading Parquet files encrypted with Parquet +Modular Encryption (PME). Decryption keys can be provided via environment +variables. Writing encrypted Parquet files is not supported. + +:::{list-table} Parquet encryption properties +:widths: 35, 50, 15 +:header-rows: 1 + +* - Property name + - Description + - Default +* - `pme.environment-key-retriever.enabled` + - Enable the key retriever that reads decryption keys from + environment variables. + - `false` +* - `pme.aad-prefix` + - AAD prefix used when decoding Parquet files. Must match the prefix used + when the files were written, if applicable. + - +* - `pme.check-footer-integrity` + - Validate signature for plaintext footer files. + - `true` +::: + +When `pme.environment-key-retriever.enabled` is set, provide keys with +environment variables: + +- `pme.environment-key-retriever.footer-keys` +- `pme.environment-key-retriever.column-keys` + +Each variable accepts either a single base64-encoded key, or a comma-separated +list of `id:key` pairs (base64-encoded keys) where `id` must match the key +metadata embedded in the Parquet file. + (hive-sql-support)= ## SQL support diff --git a/lib/trino-parquet/pom.xml b/lib/trino-parquet/pom.xml index e906de967a52..d8dffec7fc94 100644 --- a/lib/trino-parquet/pom.xml +++ b/lib/trino-parquet/pom.xml @@ -101,6 +101,12 @@ + + io.trino + trino-filesystem + provided + + io.trino trino-spi @@ -193,6 +199,12 @@ test + + org.apache.parquet + parquet-hadoop + test + + org.assertj assertj-core diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/BloomFilterStore.java b/lib/trino-parquet/src/main/java/io/trino/parquet/BloomFilterStore.java index 1afc665c1494..806625682082 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/BloomFilterStore.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/BloomFilterStore.java @@ -16,10 +16,15 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.BasicSliceInput; import io.airlift.slice.Slice; +import io.trino.parquet.crypto.AesCipherUtils; +import io.trino.parquet.crypto.ColumnDecryptionContext; +import io.trino.parquet.crypto.FileDecryptionContext; +import io.trino.parquet.crypto.ModuleType; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; +import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.values.bloomfilter.BlockSplitBloomFilter; import org.apache.parquet.column.values.bloomfilter.BloomFilter; @@ -48,21 +53,27 @@ public class BloomFilterStore private final ParquetDataSource dataSource; private final Map bloomFilterOffsets; + private final Map columnChunks; + private final Optional decryptionContext; - public BloomFilterStore(ParquetDataSource dataSource, BlockMetadata block, Set columnsFiltered) + public BloomFilterStore(ParquetDataSource dataSource, BlockMetadata block, Set columnsFiltered, Optional decryptionContext) { 
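+ // Record per-column chunk metadata alongside the bloom filter offsets so that, for encrypted
+ // columns, the row group and column ordinals are available later to rebuild the module AAD.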
this.dataSource = requireNonNull(dataSource, "dataSource is null"); requireNonNull(block, "block is null"); requireNonNull(columnsFiltered, "columnsFiltered is null"); + this.decryptionContext = requireNonNull(decryptionContext, "decryptionContext is null"); ImmutableMap.Builder bloomFilterOffsetBuilder = ImmutableMap.builder(); + ImmutableMap.Builder chunkBuilder = ImmutableMap.builder(); for (ColumnChunkMetadata column : block.columns()) { ColumnPath path = column.getPath(); if (hasBloomFilter(column) && columnsFiltered.contains(path)) { bloomFilterOffsetBuilder.put(path, column.getBloomFilterOffset()); + chunkBuilder.put(path, column); } } this.bloomFilterOffsets = bloomFilterOffsetBuilder.buildOrThrow(); + this.columnChunks = chunkBuilder.buildOrThrow(); } public Optional getBloomFilter(ColumnPath columnPath) @@ -74,9 +85,24 @@ public Optional getBloomFilter(ColumnPath columnPath) if (columnBloomFilterOffset == null) { return Optional.empty(); } - BasicSliceInput headerSliceInput = dataSource.readFully(columnBloomFilterOffset, MAX_HEADER_LENGTH).getInput(); - bloomFilterHeader = Util.readBloomFilterHeader(headerSliceInput); - bloomFilterDataOffset = columnBloomFilterOffset + headerSliceInput.position(); + // If the column is encrypted, decrypt the header using the metadata decryptor + Optional columnContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(columnPath)); + if (columnContext.isPresent()) { + // Read encrypted header module: SIZE(4) + NONCE + CIPHERTEXT + TAG + int encryptedSize = BytesUtils.readIntLittleEndian(dataSource.readFully(columnBloomFilterOffset, AesCipherUtils.SIZE_LENGTH).getBytes(), 0); + Slice module = dataSource.readFully(columnBloomFilterOffset, AesCipherUtils.SIZE_LENGTH + encryptedSize); + BasicSliceInput in = module.getInput(); + ColumnChunkMetadata chunk = requireNonNull(columnChunks.get(columnPath), "missing chunk metadata"); + byte[] aad = AesCipherUtils.createModuleAAD(columnContext.get().fileAad(), ModuleType.BloomFilterHeader, chunk.getRowGroupOrdinal(), chunk.getColumnOrdinal(), -1); + bloomFilterHeader = Util.readBloomFilterHeader(in, columnContext.get().metadataDecryptor(), aad); + // after read, position() == 4 + encrypted data length + bloomFilterDataOffset = columnBloomFilterOffset + in.position(); + } + else { + BasicSliceInput headerSliceInput = dataSource.readFully(columnBloomFilterOffset, MAX_HEADER_LENGTH).getInput(); + bloomFilterHeader = Util.readBloomFilterHeader(headerSliceInput); + bloomFilterDataOffset = columnBloomFilterOffset + headerSliceInput.position(); + } } catch (IOException exception) { throw new UncheckedIOException("Failed to read Bloom filter header", exception); @@ -87,9 +113,23 @@ public Optional getBloomFilter(ColumnPath columnPath) } try { - Slice bloomFilterData = dataSource.readFully(bloomFilterDataOffset, bloomFilterHeader.getNumBytes()); - verify(bloomFilterData.length() > 0, "Read empty bloom filter %s", bloomFilterHeader); - return Optional.of(new BlockSplitBloomFilter(bloomFilterData.getBytes())); + Optional columnContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(columnPath)); + if (columnContext.isPresent()) { + // Read the whole bitset module: SIZE + NONCE + CIPHERTEXT + TAG + int encryptedSize = BytesUtils.readIntLittleEndian(dataSource.readFully(bloomFilterDataOffset, AesCipherUtils.SIZE_LENGTH).getBytes(), 0); + Slice module = dataSource.readFully(bloomFilterDataOffset, AesCipherUtils.SIZE_LENGTH + encryptedSize); + ColumnChunkMetadata chunk = 
requireNonNull(columnChunks.get(columnPath), "missing chunk metadata"); + byte[] aad = AesCipherUtils.createModuleAAD(columnContext.get().fileAad(), ModuleType.BloomFilterBitset, chunk.getRowGroupOrdinal(), chunk.getColumnOrdinal(), -1); + byte[] plain = columnContext.get().metadataDecryptor().decrypt(module.getBytes(), aad); + verify(plain.length == bloomFilterHeader.getNumBytes(), "Decrypted bloom filter length mismatch: expected %s, got %s", bloomFilterHeader.getNumBytes(), plain.length); + return Optional.of(new BlockSplitBloomFilter(plain)); + } + else { + // Plaintext bitset + Slice bloomFilterData = dataSource.readFully(bloomFilterDataOffset, bloomFilterHeader.getNumBytes()); + verify(bloomFilterData.length() > 0, "Read empty bloom filter %s", bloomFilterHeader); + return Optional.of(new BlockSplitBloomFilter(bloomFilterData.getBytes())); + } } catch (IOException exception) { throw new UncheckedIOException("Failed to read Bloom filter data", exception); @@ -100,7 +140,8 @@ public static Optional getBloomFilterStore( ParquetDataSource dataSource, BlockMetadata blockMetadata, TupleDomain parquetTupleDomain, - ParquetReaderOptions options) + ParquetReaderOptions options, + Optional decryptionContext) { if (!options.useBloomFilter() || parquetTupleDomain.isAll() || parquetTupleDomain.isNone()) { return Optional.empty(); @@ -117,7 +158,7 @@ public static Optional getBloomFilterStore( .map(column -> ColumnPath.get(column.getPath())) .collect(toImmutableSet()); - return Optional.of(new BloomFilterStore(dataSource, blockMetadata, columnsFilteredPaths)); + return Optional.of(new BloomFilterStore(dataSource, blockMetadata, columnsFilteredPaths, decryptionContext)); } public static boolean hasBloomFilter(ColumnChunkMetadata columnMetaData) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java index bbece17c9b7e..8eaebc60c93d 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPage.java @@ -21,12 +21,14 @@ public abstract sealed class DataPage { protected final int valueCount; private final OptionalLong firstRowIndex; + private final int pageIndex; - public DataPage(int uncompressedSize, int valueCount, OptionalLong firstRowIndex) + public DataPage(int uncompressedSize, int valueCount, OptionalLong firstRowIndex, int pageIndex) { super(uncompressedSize); this.valueCount = valueCount; this.firstRowIndex = firstRowIndex; + this.pageIndex = pageIndex; } /** @@ -41,4 +43,9 @@ public int getValueCount() { return valueCount; } + + public int getPageIndex() + { + return pageIndex; + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java index b0895445d813..8dbf9809378d 100755 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV1.java @@ -35,15 +35,17 @@ public DataPageV1( OptionalLong firstRowIndex, ParquetEncoding repetitionLevelEncoding, ParquetEncoding definitionLevelEncoding, - ParquetEncoding valuesEncoding) + ParquetEncoding valuesEncoding, + int pageIndex) { - super(uncompressedSize, valueCount, firstRowIndex); + super(uncompressedSize, valueCount, firstRowIndex, pageIndex); this.slice = requireNonNull(slice, "slice is null"); this.repetitionLevelEncoding = repetitionLevelEncoding; this.definitionLevelEncoding = definitionLevelEncoding; 
this.valuesEncoding = valuesEncoding; } + @Override public Slice getSlice() { return slice; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java index b0cbfd9ed8fc..6544942e74eb 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DataPageV2.java @@ -44,9 +44,10 @@ public DataPageV2( int uncompressedSize, OptionalLong firstRowIndex, Statistics statistics, - boolean isCompressed) + boolean isCompressed, + int pageIndex) { - super(uncompressedSize, valueCount, firstRowIndex); + super(uncompressedSize, valueCount, firstRowIndex, pageIndex); this.rowCount = rowCount; this.nullCount = nullCount; this.repetitionLevels = requireNonNull(repetitionLevels, "repetitionLevels slice is null"); @@ -82,6 +83,7 @@ public ParquetEncoding getDataEncoding() return dataEncoding; } + @Override public Slice getSlice() { return slice; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java b/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java index 74fdf540199d..bd92d7fc0c8e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/DictionaryPage.java @@ -43,6 +43,7 @@ public DictionaryPage(Slice slice, int uncompressedSize, int dictionarySize, Par encoding); } + @Override public Slice getSlice() { return slice; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java b/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java index 69cde62cf435..64b1f861717b 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/Page.java @@ -13,6 +13,8 @@ */ package io.trino.parquet; +import io.airlift.slice.Slice; + public abstract class Page { protected final int uncompressedSize; @@ -26,4 +28,6 @@ public int getUncompressedSize() { return uncompressedSize; } + + public abstract Slice getSlice(); } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java index 5e31a08f704e..5a1d06861f60 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java @@ -14,6 +14,7 @@ package io.trino.parquet; import com.google.errorprone.annotations.FormatMethod; +import io.trino.parquet.crypto.ParquetCryptoException; public final class ParquetValidationUtils { @@ -27,4 +28,12 @@ public static void validateParquet(boolean condition, ParquetDataSourceId dataSo throw new ParquetCorruptionException(dataSourceId, formatString, args); } } + + @FormatMethod + public static void validateParquetCrypto(boolean condition, ParquetDataSourceId dataSourceId, String formatString, Object... args) + { + if (!condition) { + throw new ParquetCryptoException(dataSourceId, formatString, args); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCipherUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCipherUtils.java new file mode 100644 index 000000000000..25f872daa68f --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCipherUtils.java @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import java.io.IOException; +import java.io.InputStream; + +import static com.google.common.primitives.Bytes.concat; +import static java.util.Objects.requireNonNull; + +public final class AesCipherUtils +{ + public static final int SIZE_LENGTH = 4; + public static final int NONCE_LENGTH = 12; + public static final int GCM_TAG_LENGTH = 16; + public static final int CTR_IV_LENGTH = 16; + public static final int GCM_TAG_LENGTH_BITS = 8 * GCM_TAG_LENGTH; + public static final int CHUNK_LENGTH = 4 * 1024; + // NIST SP 800-38D section 8.3 specifies limit on AES GCM encryption operations with same key and random IV/nonce + public static final long GCM_RANDOM_IV_SAME_KEY_MAX_OPS = 1L << 32; + + private AesCipherUtils() {} + + public static void validateKeyBytes(byte[] keyBytes) + { + requireNonNull(keyBytes, "key bytes cannot be null"); + boolean allZeroKey = true; + for (byte kb : keyBytes) { + if (kb != 0) { + allZeroKey = false; + break; + } + } + + if (allZeroKey) { + throw new IllegalArgumentException("All key bytes are zero"); + } + } + + public static byte[] createModuleAAD(byte[] fileAAD, ModuleType moduleType, int rowGroupOrdinal, int columnOrdinal, int pageOrdinal) + { + byte[] typeOrdinalBytes = new byte[1]; + typeOrdinalBytes[0] = moduleType.getValue(); + + if (ModuleType.Footer == moduleType) { + return concat(fileAAD, typeOrdinalBytes); + } + + if (rowGroupOrdinal < 0) { + throw new IllegalArgumentException("Wrong row group ordinal: " + rowGroupOrdinal); + } + short shortRGOrdinal = (short) rowGroupOrdinal; + if (shortRGOrdinal != rowGroupOrdinal) { + throw new ParquetCryptoException("Encrypted parquet files can't have more than %s row groups: %s", Short.MAX_VALUE, rowGroupOrdinal); + } + byte[] rowGroupOrdinalBytes = shortToBytesLittleEndian(shortRGOrdinal); + + if (columnOrdinal < 0) { + throw new IllegalArgumentException("Wrong column ordinal: " + columnOrdinal); + } + short shortColumOrdinal = (short) columnOrdinal; + if (shortColumOrdinal != columnOrdinal) { + throw new ParquetCryptoException("Encrypted parquet files can't have more than %s columns: %s", Short.MAX_VALUE, columnOrdinal); + } + byte[] columnOrdinalBytes = shortToBytesLittleEndian(shortColumOrdinal); + + if (ModuleType.DataPage != moduleType && ModuleType.DataPageHeader != moduleType) { + return concat(fileAAD, typeOrdinalBytes, rowGroupOrdinalBytes, columnOrdinalBytes); + } + + if (pageOrdinal < 0) { + throw new IllegalArgumentException("Wrong page ordinal: " + pageOrdinal); + } + short shortPageOrdinal = (short) pageOrdinal; + if (shortPageOrdinal != pageOrdinal) { + throw new ParquetCryptoException("Encrypted parquet files can't have more than %s pages per chunk: %s", Short.MAX_VALUE, pageOrdinal); + } + byte[] pageOrdinalBytes = shortToBytesLittleEndian(shortPageOrdinal); + + return concat(fileAAD, typeOrdinalBytes, rowGroupOrdinalBytes, columnOrdinalBytes, pageOrdinalBytes); + } + + public static byte[] createFooterAAD(byte[] aadPrefixBytes) + { + return createModuleAAD(aadPrefixBytes, ModuleType.Footer, -1, -1, -1); + } + + // Update last two bytes 
with new page ordinal (instead of creating new page AAD from scratch) + public static void quickUpdatePageAAD(byte[] pageAAD, int newPageOrdinal) + { + requireNonNull(pageAAD, "pageAAD cannot be null"); + if (newPageOrdinal < 0) { + throw new IllegalArgumentException("Wrong page ordinal: " + newPageOrdinal); + } + short shortPageOrdinal = (short) newPageOrdinal; + if (shortPageOrdinal != newPageOrdinal) { + throw new ParquetCryptoException("Encrypted parquet files can't have more than %s pages per chunk: %s", Short.MAX_VALUE, newPageOrdinal); + } + + byte[] pageOrdinalBytes = shortToBytesLittleEndian(shortPageOrdinal); + System.arraycopy(pageOrdinalBytes, 0, pageAAD, pageAAD.length - 2, 2); + } + + public static int readCiphertextLength(InputStream from) + throws IOException + { + byte[] lengthBuffer = new byte[SIZE_LENGTH]; + int readBytes = 0; + + // Read the length of encrypted Thrift structure + while (readBytes < SIZE_LENGTH) { + int n = from.read(lengthBuffer, readBytes, SIZE_LENGTH - readBytes); + if (n <= 0) { + throw new IOException("Tried to read int (4 bytes), but only got " + readBytes + " bytes."); + } + readBytes += n; + } + + int ciphertextLength = ((lengthBuffer[3] & 0xff) << 24) + | ((lengthBuffer[2] & 0xff) << 16) + | ((lengthBuffer[1] & 0xff) << 8) + | (lengthBuffer[0] & 0xff); + + if (ciphertextLength < 1) { + throw new IOException("Wrong length of encrypted metadata: " + ciphertextLength); + } + + return ciphertextLength; + } + + private static byte[] shortToBytesLittleEndian(short input) + { + byte[] output = new byte[2]; + output[1] = (byte) (0xff & (input >> 8)); + output[0] = (byte) (0xff & input); + + return output; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrDecryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrDecryptor.java new file mode 100644 index 000000000000..b4740c63ccae --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesCtrDecryptor.java @@ -0,0 +1,186 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.Cipher; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Objects; + +import static io.trino.parquet.crypto.AesCipherUtils.CHUNK_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.CTR_IV_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.NONCE_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.SIZE_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.readCiphertextLength; +import static io.trino.parquet.crypto.AesCipherUtils.validateKeyBytes; + +public class AesCtrDecryptor + implements BlockCipher.Decryptor +{ + private final byte[] keyBytes; + private final Cipher cipher; + private final SecretKeySpec aesKey; + private final byte[] ctrIV; + + public AesCtrDecryptor(byte[] keyBytes) + { + validateKeyBytes(keyBytes); + this.keyBytes = keyBytes; + + try { + cipher = Cipher.getInstance(AesMode.CTR.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to create CTR cipher"); + } + aesKey = new SecretKeySpec(keyBytes, "AES"); + ctrIV = new byte[CTR_IV_LENGTH]; + // Setting last bit of initial CTR counter to 1 + ctrIV[CTR_IV_LENGTH - 1] = (byte) 1; + } + + @Override + public byte[] decrypt(byte[] lengthAndCiphertext, byte[] aad) + { + return decrypt(lengthAndCiphertext, SIZE_LENGTH, lengthAndCiphertext.length - SIZE_LENGTH, aad); + } + + public byte[] decrypt(byte[] ciphertext, int cipherTextOffset, int cipherTextLength, byte[] aad) + throws ParquetCryptoException + { + int plainTextLength = cipherTextLength - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoException("Wrong input length %s", plainTextLength); + } + + // Get the nonce from ciphertext + System.arraycopy(ciphertext, cipherTextOffset, ctrIV, 0, NONCE_LENGTH); + + byte[] plainText = new byte[plainTextLength]; + int inputLength = cipherTextLength - NONCE_LENGTH; + int inputOffset = cipherTextOffset + NONCE_LENGTH; + int outputOffset = 0; + try { + IvParameterSpec spec = new IvParameterSpec(ctrIV); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + + // Breaking decryption into multiple updates, to trigger h/w acceleration in Java 9+ + while (inputLength > CHUNK_LENGTH) { + int written = cipher.update(ciphertext, inputOffset, CHUNK_LENGTH, plainText, outputOffset); + inputOffset += CHUNK_LENGTH; + outputOffset += written; + inputLength -= CHUNK_LENGTH; + } + + cipher.doFinal(ciphertext, inputOffset, inputLength, plainText, outputOffset); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to decrypt"); + } + + return plainText; + } + + @Override + public ByteBuffer decrypt(ByteBuffer ciphertext, byte[] aad) + { + int cipherTextOffset = SIZE_LENGTH; + int cipherTextLength = ciphertext.limit() - ciphertext.position() - SIZE_LENGTH; + + int plainTextLength = cipherTextLength - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoException("Wrong input length %s", plainTextLength); + } + + // skip size + ciphertext.position(ciphertext.position() + cipherTextOffset); + // Get the nonce from ciphertext + ciphertext.get(ctrIV, 0, NONCE_LENGTH); + + // Reuse the input buffer as the output buffer + ByteBuffer plainText = ciphertext.slice(); + 
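+ // The slice shares the input buffer's backing memory, so the CTR decryption below writes the
+ // plaintext in place over the ciphertext region.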
plainText.limit(plainTextLength); + int inputLength = cipherTextLength - NONCE_LENGTH; + int inputOffset = cipherTextOffset + NONCE_LENGTH; + try { + IvParameterSpec spec = new IvParameterSpec(ctrIV); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + + // Breaking decryption into multiple updates, to trigger h/w acceleration in Java 9+ + while (inputLength > CHUNK_LENGTH) { + ciphertext.position(inputOffset); + ciphertext.limit(inputOffset + CHUNK_LENGTH); + cipher.update(ciphertext, plainText); + inputOffset += CHUNK_LENGTH; + inputLength -= CHUNK_LENGTH; + } + ciphertext.position(inputOffset); + ciphertext.limit(inputOffset + inputLength); + cipher.doFinal(ciphertext, plainText); + plainText.flip(); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to decrypt"); + } + + return plainText; + } + + @Override + public byte[] decrypt(InputStream from, byte[] aad) + throws IOException + { + int ciphertextLength = readCiphertextLength(from); + // Read the encrypted structure contents + byte[] ciphertextBuffer = new byte[ciphertextLength]; + int readBytes = 0; + while (readBytes < ciphertextLength) { + int n = from.read(ciphertextBuffer, readBytes, ciphertextLength - readBytes); + if (n <= 0) { + throw new IOException( + "Tried to read " + ciphertextLength + " bytes, but only got " + readBytes + " bytes."); + } + readBytes += n; + } + + // Decrypt the structure contents + return decrypt(ciphertextBuffer, 0, ciphertextLength, aad); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (!(o instanceof AesCtrDecryptor that)) { + return false; + } + return Objects.deepEquals(keyBytes, that.keyBytes); + } + + @Override + public int hashCode() + { + return Arrays.hashCode(keyBytes); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmDecryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmDecryptor.java new file mode 100644 index 000000000000..d1b716be5cd1 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmDecryptor.java @@ -0,0 +1,175 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.AEADBadTagException; +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Objects; + +import static io.trino.parquet.crypto.AesCipherUtils.GCM_TAG_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.GCM_TAG_LENGTH_BITS; +import static io.trino.parquet.crypto.AesCipherUtils.NONCE_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.SIZE_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.readCiphertextLength; +import static io.trino.parquet.crypto.AesCipherUtils.validateKeyBytes; + +public class AesGcmDecryptor + implements BlockCipher.Decryptor +{ + private final byte[] keyBytes; + private final Cipher cipher; + private final SecretKeySpec aesKey; + private final byte[] localNonce; + + public AesGcmDecryptor(byte[] keyBytes) + { + validateKeyBytes(keyBytes); + this.keyBytes = keyBytes; + + try { + cipher = Cipher.getInstance(AesMode.GCM.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to create GCM cipher"); + } + + aesKey = new SecretKeySpec(keyBytes, "AES"); + localNonce = new byte[NONCE_LENGTH]; + } + + @Override + public byte[] decrypt(byte[] lengthAndCiphertext, byte[] aad) + { + return decrypt(lengthAndCiphertext, SIZE_LENGTH, lengthAndCiphertext.length - SIZE_LENGTH, aad); + } + + public byte[] decrypt(byte[] ciphertext, int cipherTextOffset, int cipherTextLength, byte[] aad) + { + int plainTextLength = cipherTextLength - GCM_TAG_LENGTH - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoException("Wrong input length %s", plainTextLength); + } + + // Get the nonce from ciphertext + System.arraycopy(ciphertext, cipherTextOffset, localNonce, 0, NONCE_LENGTH); + + byte[] plainText = new byte[plainTextLength]; + int inputLength = cipherTextLength - NONCE_LENGTH; + int inputOffset = cipherTextOffset + NONCE_LENGTH; + int outputOffset = 0; + try { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, localNonce); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + if (null != aad) { + cipher.updateAAD(aad); + } + + cipher.doFinal(ciphertext, inputOffset, inputLength, plainText, outputOffset); + } + catch (AEADBadTagException e) { + throw new ParquetCryptoException(e, "GCM tag check failed"); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to decrypt"); + } + + return plainText; + } + + @Override + public ByteBuffer decrypt(ByteBuffer ciphertext, byte[] aad) + { + int cipherTextOffset = SIZE_LENGTH; + int cipherTextLength = ciphertext.limit() - ciphertext.position() - SIZE_LENGTH; + int plainTextLength = cipherTextLength - GCM_TAG_LENGTH - NONCE_LENGTH; + if (plainTextLength < 1) { + throw new ParquetCryptoException("Wrong input length %s", plainTextLength); + } + + ciphertext.position(ciphertext.position() + cipherTextOffset); + // Get the nonce from ciphertext + ciphertext.get(localNonce); + + // Reuse the input buffer as the output buffer + ByteBuffer plainText = ciphertext.slice(); + plainText.limit(plainTextLength); + try { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, localNonce); + cipher.init(Cipher.DECRYPT_MODE, aesKey, spec); + if (null != aad) { + 
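+ // GCM authenticates the module AAD together with the ciphertext; a mismatched AAD
+ // surfaces as the AEADBadTagException handled below.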
cipher.updateAAD(aad); + } + + cipher.doFinal(ciphertext, plainText); + plainText.flip(); + } + catch (AEADBadTagException e) { + throw new ParquetCryptoException(e, "GCM tag check failed"); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to decrypt"); + } + + return plainText; + } + + @Override + public byte[] decrypt(InputStream from, byte[] aad) + throws IOException + { + int ciphertextLength = readCiphertextLength(from); + // Read the encrypted structure contents + byte[] ciphertextBuffer = new byte[ciphertextLength]; + int readBytes = 0; + // Read the encrypted structure contents + while (readBytes < ciphertextLength) { + int n = from.read(ciphertextBuffer, readBytes, ciphertextLength - readBytes); + if (n <= 0) { + throw new IOException("Tried to read " + ciphertextLength + " bytes, but only got " + readBytes + " bytes."); + } + readBytes += n; + } + + // Decrypt the structure contents + return decrypt(ciphertextBuffer, 0, ciphertextLength, aad); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (!(o instanceof AesGcmDecryptor that)) { + return false; + } + return Objects.deepEquals(keyBytes, that.keyBytes); + } + + @Override + public int hashCode() + { + return Arrays.hashCode(keyBytes); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmEncryptor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmEncryptor.java new file mode 100644 index 000000000000..f2034731214d --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesGcmEncryptor.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.format.BlockCipher; + +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; + +import java.security.GeneralSecurityException; +import java.security.SecureRandom; + +import static io.trino.parquet.crypto.AesCipherUtils.GCM_RANDOM_IV_SAME_KEY_MAX_OPS; +import static io.trino.parquet.crypto.AesCipherUtils.GCM_TAG_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.GCM_TAG_LENGTH_BITS; +import static io.trino.parquet.crypto.AesCipherUtils.NONCE_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.SIZE_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.validateKeyBytes; + +public class AesGcmEncryptor + implements BlockCipher.Encryptor +{ + protected Cipher cipher; + protected final SecureRandom randomGenerator; + protected final byte[] localNonce; + protected SecretKeySpec aesKey; + + private long operationCounter; + + public AesGcmEncryptor(byte[] keyBytes) + { + validateKeyBytes(keyBytes); + operationCounter = 0; + + try { + cipher = Cipher.getInstance(AesMode.GCM.getCipherName()); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to create GCM cipher"); + } + + aesKey = new SecretKeySpec(keyBytes, "AES"); + randomGenerator = new SecureRandom(); + localNonce = new byte[NONCE_LENGTH]; + } + + @Override + public byte[] encrypt(byte[] plainText, byte[] aad) + { + return encrypt(true, plainText, aad); + } + + public byte[] encrypt(boolean writeLength, byte[] plainText, byte[] aad) + { + randomGenerator.nextBytes(localNonce); + return encrypt(writeLength, plainText, localNonce, aad); + } + + public byte[] encrypt(boolean writeLength, byte[] plainText, byte[] nonce, byte[] aad) + { + if (operationCounter > GCM_RANDOM_IV_SAME_KEY_MAX_OPS) { + throw new ParquetCryptoException("Exceeded limit of AES GCM encryption operations with same key and random IV"); + } + operationCounter++; + + if (nonce.length != NONCE_LENGTH) { + throw new ParquetCryptoException("Wrong nonce length %s", nonce.length); + } + int plainTextLength = plainText.length; + int cipherTextLength = NONCE_LENGTH + plainTextLength + GCM_TAG_LENGTH; + int lengthBufferLength = writeLength ? 
SIZE_LENGTH : 0; + byte[] cipherText = new byte[lengthBufferLength + cipherTextLength]; + int inputOffset = 0; + int outputOffset = lengthBufferLength + NONCE_LENGTH; + + try { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, nonce); + cipher.init(Cipher.ENCRYPT_MODE, aesKey, spec); + if (null != aad) { + cipher.updateAAD(aad); + } + + cipher.doFinal(plainText, inputOffset, plainTextLength, cipherText, outputOffset); + } + catch (GeneralSecurityException e) { + throw new ParquetCryptoException(e, "Failed to encrypt"); + } + + // Add ciphertext length + if (writeLength) { + System.arraycopy(BytesUtils.intToBytes(cipherTextLength), 0, cipherText, 0, lengthBufferLength); + } + // Add the nonce + System.arraycopy(nonce, 0, cipherText, lengthBufferLength, NONCE_LENGTH); + + return cipherText; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesMode.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesMode.java new file mode 100644 index 000000000000..e8affac6c9f0 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/AesMode.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +public enum AesMode +{ + GCM("AES/GCM/NoPadding"), + CTR("AES/CTR/NoPadding"); + + private final String cipherName; + + AesMode(String cipherName) + { + this.cipherName = cipherName; + } + + public String getCipherName() + { + return cipherName; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ColumnDecryptionContext.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ColumnDecryptionContext.java new file mode 100644 index 000000000000..c251c2149c7d --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ColumnDecryptionContext.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import org.apache.parquet.format.BlockCipher.Decryptor; + +import java.util.Arrays; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public record ColumnDecryptionContext(Decryptor dataDecryptor, Decryptor metadataDecryptor, byte[] fileAad) +{ + public ColumnDecryptionContext(Decryptor dataDecryptor, Decryptor metadataDecryptor, byte[] fileAad) + { + this.dataDecryptor = requireNonNull(dataDecryptor, "dataDecryptor is null"); + this.metadataDecryptor = requireNonNull(metadataDecryptor, "metadataDecryptor is null"); + this.fileAad = requireNonNull(fileAad, "fileAad is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (!(o instanceof ColumnDecryptionContext context)) { + return false; + } + return Objects.equals(dataDecryptor, context.dataDecryptor) + && Objects.equals(metadataDecryptor, context.metadataDecryptor) + && Objects.deepEquals(fileAad, context.fileAad); + } + + @Override + public int hashCode() + { + return Objects.hash(dataDecryptor, metadataDecryptor, Arrays.hashCode(fileAad)); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/DecryptionKeyRetriever.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/DecryptionKeyRetriever.java new file mode 100644 index 000000000000..130a84534bbf --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/DecryptionKeyRetriever.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.util.Optional; + +/** + * Interface for classes retrieving encryption keys using the key metadata. + * Implementations must be thread-safe, if same {@link DecryptionKeyRetriever} object is passed to multiple file readers. + */ +public interface DecryptionKeyRetriever +{ + /** + * Returns key for a given column and the key metadata. Should return empty if user does not have access to the column key or key doesn't exist. + */ + Optional getColumnKey(ColumnPath columnPath, Optional keyMetadata); + + /** + * Returns key for a footer and the key metadata. Should return empty if user does not have access to the column key or key doesn't exist. + */ + Optional getFooterKey(Optional keyMetadata); +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionContext.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionContext.java new file mode 100644 index 000000000000..f77b8fa37f47 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionContext.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import io.airlift.log.Logger; +import io.trino.parquet.ParquetDataSourceId; +import org.apache.parquet.format.BlockCipher.Decryptor; +import org.apache.parquet.format.EncryptionAlgorithm; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.primitives.Bytes.concat; +import static io.trino.parquet.ParquetValidationUtils.validateParquetCrypto; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class FileDecryptionContext +{ + private static final Logger log = Logger.get(FileDecryptionContext.class); + + private final Map> columnDecryptionContext = new HashMap<>(); + + private final ParquetDataSourceId dataSourceId; + private final DecryptionKeyRetriever keyRetriever; + private final EncryptionAlgorithm algorithm; + private final Optional footerKey; + private final byte[] fileAad; + + private AesGcmDecryptor aesGcmDecryptorWithFooterKey; + private AesCtrDecryptor aesCtrDecryptorWithFooterKey; + + public FileDecryptionContext(ParquetDataSourceId dataSourceId, FileDecryptionProperties fileDecryptionProperties, EncryptionAlgorithm algorithm, Optional footerKeyMetadata) + { + log.debug("File Decryptor. 
Algo: %s", algorithm); + + requireNonNull(fileDecryptionProperties, "fileDecryptionProperties is null"); + this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); + this.keyRetriever = fileDecryptionProperties.getKeyRetriever(); + this.algorithm = requireNonNull(algorithm, "algorithm is null"); + + byte[] aadFileUnique; + boolean mustSupplyAadPrefix; + byte[] aadPrefixInFile = null; + + // Process encryption algorithm metadata + if (algorithm.isSetAES_GCM_V1()) { + if (algorithm.getAES_GCM_V1().isSetAad_prefix()) { + aadPrefixInFile = algorithm.getAES_GCM_V1().getAad_prefix(); + } + mustSupplyAadPrefix = algorithm.getAES_GCM_V1().isSupply_aad_prefix(); + aadFileUnique = algorithm.getAES_GCM_V1().getAad_file_unique(); + } + else if (algorithm.isSetAES_GCM_CTR_V1()) { + if (algorithm.getAES_GCM_CTR_V1().isSetAad_prefix()) { + aadPrefixInFile = algorithm.getAES_GCM_CTR_V1().getAad_prefix(); + } + mustSupplyAadPrefix = algorithm.getAES_GCM_CTR_V1().isSupply_aad_prefix(); + aadFileUnique = algorithm.getAES_GCM_CTR_V1().getAad_file_unique(); + } + else { + throw new UnsupportedOperationException(format("Unsupported algorithm: %s", algorithm)); + } + + // Determine AAD prefix, in-file AAD prefix takes precedence + Optional aadPrefix = Optional.ofNullable(aadPrefixInFile).or(fileDecryptionProperties::getAadPrefix); + validateParquetCrypto(!mustSupplyAadPrefix || aadPrefix.isPresent(), dataSourceId, "AAD prefix must be supplied"); + fileAad = aadPrefix.map(bytes -> concat(bytes, aadFileUnique)).orElse(aadFileUnique); + footerKey = fileDecryptionProperties.getKeyRetriever().getFooterKey(footerKeyMetadata); + } + + public Optional getColumnDecryptionContext(ColumnPath path) + { + Optional context = columnDecryptionContext.get(path); + checkArgument(context != null, "Column %s not found in decryption context", path); + return context; + } + + public Decryptor getFooterDecryptor() + { + validateParquetCrypto(footerKey.isPresent(), dataSourceId, "User does not have access to footer or footer key does not exists"); + return getThriftModuleDecryptor(Optional.empty()); + } + + public AesGcmEncryptor getFooterEncryptor() + { + validateParquetCrypto(footerKey.isPresent(), dataSourceId, "User does not have access to footer or footer key does not exists"); + return new AesGcmEncryptor(footerKey.get()); + } + + public byte[] getFileAad() + { + return this.fileAad; + } + + public void initPlaintextColumn(ColumnPath path) + { + log.debug("Column decryption (plaintext): %s", path); + setColumnDecryptionContext(path, Optional.empty()); + } + + public Optional initializeColumnCryptoMetadata(ColumnPath path, boolean encryptedWithFooterKey, Optional columnKeyMetadata) + { + log.debug("Column decryption (footer key): %s", path); + + Optional context; + if (encryptedWithFooterKey) { + if (footerKey.isEmpty()) { + // User does not have access to the footer key. Column is considered hidden. + setColumnDecryptionContext(path, Optional.empty()); + return Optional.empty(); + } + context = Optional.of(new ColumnDecryptionContext(getDataModuleDecryptor(Optional.empty()), getThriftModuleDecryptor(Optional.empty()), fileAad)); + } + else { + // Column is encrypted with column-specific key + Optional columnKeyBytes = requireNonNull(keyRetriever.getColumnKey(path, columnKeyMetadata), format("Column key for %s not found", path)); + if (columnKeyBytes.isEmpty()) { + // User does not have access to the column key. Column is considered hidden. 
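+ // Still register an entry for this path (with an empty context) so getColumnDecryptionContext
+ // succeeds and the column is treated as hidden rather than decrypted.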
+ setColumnDecryptionContext(path, Optional.empty()); + return Optional.empty(); + } + context = Optional.of(new ColumnDecryptionContext(getDataModuleDecryptor(columnKeyBytes), getThriftModuleDecryptor(columnKeyBytes), fileAad)); + } + + setColumnDecryptionContext(path, context); + return context; + } + + private void setColumnDecryptionContext(ColumnPath path, Optional context) + { + checkArgument(!columnDecryptionContext.containsKey(path) || columnDecryptionContext.get(path).equals(context), "Mismatching column %s encryption context already exists in decryption context", path); + columnDecryptionContext.put(path, context); + } + + private Decryptor getDataModuleDecryptor(Optional columnKey) + { + if (algorithm.isSetAES_GCM_V1()) { + return getThriftModuleDecryptor(columnKey); + } + + // AES_GCM_CTR_V1 + if (columnKey.isEmpty()) { + // Decryptor with footer key + if (aesCtrDecryptorWithFooterKey == null) { + aesCtrDecryptorWithFooterKey = new AesCtrDecryptor(footerKey.get()); + } + return aesCtrDecryptorWithFooterKey; + } + else { + // Decryptor with column key + return new AesCtrDecryptor(columnKey.orElseThrow()); + } + } + + private Decryptor getThriftModuleDecryptor(Optional columnKey) + { + if (columnKey.isEmpty()) { + // Decryptor with footer key + if (aesGcmDecryptorWithFooterKey == null) { + aesGcmDecryptorWithFooterKey = new AesGcmDecryptor(footerKey.get()); + } + return aesGcmDecryptorWithFooterKey; + } + + // Decryptor with column key + return new AesGcmDecryptor(columnKey.get()); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionProperties.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionProperties.java new file mode 100644 index 000000000000..574ca1252829 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/FileDecryptionProperties.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class FileDecryptionProperties +{ + private final DecryptionKeyRetriever keyRetriever; + private final Optional aadPrefix; + private final boolean checkFooterIntegrity; + + private FileDecryptionProperties(DecryptionKeyRetriever keyRetriever, Optional aadPrefix, boolean checkFooterIntegrity) + { + this.keyRetriever = requireNonNull(keyRetriever, "keyRetriever is null"); + this.aadPrefix = requireNonNull(aadPrefix, "aadPrefix is null"); + this.checkFooterIntegrity = checkFooterIntegrity; + } + + public static Builder builder() + { + return new Builder(); + } + + public DecryptionKeyRetriever getKeyRetriever() + { + return keyRetriever; + } + + public Optional getAadPrefix() + { + return aadPrefix; + } + + public boolean isCheckFooterIntegrity() + { + return checkFooterIntegrity; + } + + public static class Builder + { + private DecryptionKeyRetriever keyRetriever; + private Optional aadPrefix = Optional.empty(); + private boolean checkFooterIntegrity = true; + + public Builder withKeyRetriever(DecryptionKeyRetriever keyRetriever) + { + this.keyRetriever = requireNonNull(keyRetriever, "keyRetriever is null"); + return this; + } + + public Builder withAadPrefix(byte[] aadPrefix) + { + this.aadPrefix = Optional.of(requireNonNull(aadPrefix, "aadPrefix is null")); + return this; + } + + public Builder withCheckFooterIntegrity(boolean checkFooterIntegrity) + { + this.checkFooterIntegrity = checkFooterIntegrity; + return this; + } + + public FileDecryptionProperties build() + { + return new FileDecryptionProperties(keyRetriever, aadPrefix, checkFooterIntegrity); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ModuleType.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ModuleType.java new file mode 100644 index 000000000000..af56ad50a54c --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ModuleType.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +public enum ModuleType +{ + Footer((byte) 0), + ColumnMetaData((byte) 1), + DataPage((byte) 2), + DictionaryPage((byte) 3), + DataPageHeader((byte) 4), + DictionaryPageHeader((byte) 5), + ColumnIndex((byte) 6), + OffsetIndex((byte) 7), + BloomFilterHeader((byte) 8), + BloomFilterBitset((byte) 9); + + private final byte value; + + ModuleType(byte value) + { + this.value = value; + } + + public byte getValue() + { + return value; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCryptoException.java b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCryptoException.java new file mode 100644 index 000000000000..32d18aea2b47 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/crypto/ParquetCryptoException.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.crypto; + +import com.google.errorprone.annotations.FormatMethod; +import io.trino.parquet.ParquetDataSourceId; + +import java.util.Optional; + +import static java.lang.String.format; + +public class ParquetCryptoException + extends RuntimeException +{ + @FormatMethod + public ParquetCryptoException(String messageFormat, Object... args) + { + super(formatMessage(Optional.empty(), messageFormat, args)); + } + + @FormatMethod + public ParquetCryptoException(Throwable cause, String messageFormat, Object... args) + { + super(formatMessage(Optional.empty(), messageFormat, args), cause); + } + + @FormatMethod + public ParquetCryptoException(ParquetDataSourceId dataSourceId, String messageFormat, Object... args) + { + super(formatMessage(Optional.of(dataSourceId), messageFormat, args)); + } + + private static String formatMessage(Optional dataSourceId, String messageFormat, Object[] args) + { + if (dataSourceId.isEmpty()) { + return "Parquet cryptographic error. " + format(messageFormat, args); + } + + return "Parquet cryptographic error. 
" + format(messageFormat, args) + " [" + dataSourceId + "]"; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java index 939bc399037e..c057a6207689 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/BlockMetadata.java @@ -17,8 +17,4 @@ public record BlockMetadata(long fileRowCountOffset, long rowCount, List columns) { - public long getStartingPos() - { - return columns().getFirst().getStartingPos(); - } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java index 381260829869..62d3d9869d67 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ColumnChunkMetadata.java @@ -23,9 +23,13 @@ import java.util.Set; +import static org.apache.parquet.column.Encoding.PLAIN_DICTIONARY; +import static org.apache.parquet.column.Encoding.RLE_DICTIONARY; + public abstract class ColumnChunkMetadata { - protected int rowGroupOrdinal = -1; + private int rowGroupOrdinal = -1; + private int columnOrdinal = -1; public static ColumnChunkMetadata get( ColumnPath path, @@ -76,6 +80,16 @@ public int getRowGroupOrdinal() return rowGroupOrdinal; } + public void setColumnOrdinal(int columnOrdinal) + { + this.columnOrdinal = columnOrdinal; + } + + public int getColumnOrdinal() + { + return columnOrdinal; + } + public long getStartingPos() { decryptIfNeeded(); @@ -194,6 +208,18 @@ public EncodingStats getEncodingStats() return encodingStats; } + public boolean hasDictionaryPage() + { + decryptIfNeeded(); + if (encodingStats != null) { + // ensure there is a dictionary page and that it is used to encode data pages + return encodingStats.hasDictionaryPages() && encodingStats.hasDictionaryEncodedPages(); + } + + Set encodings = properties.encodings(); + return (encodings.contains(PLAIN_DICTIONARY) || encodings.contains(RLE_DICTIONARY)); + } + @Override public String toString() { diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/HiddenColumnChunkMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/HiddenColumnChunkMetadata.java new file mode 100644 index 000000000000..95db2cb4ca9c --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/HiddenColumnChunkMetadata.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.metadata; + +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.ParquetCryptoException; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import static java.util.Objects.requireNonNull; + +public class HiddenColumnChunkMetadata + extends ColumnChunkMetadata +{ + private final ParquetDataSourceId dataSourceId; + private final ColumnPath path; + + public HiddenColumnChunkMetadata(ParquetDataSourceId dataSourceId, ColumnPath path) + { + super(null, null); + this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); + this.path = requireNonNull(path, "path is null"); + } + + @Override + public ColumnPath getPath() + { + return path; + } + + @Override + public long getFirstDataPageOffset() + { + throw hiddenColumnException(); + } + + @Override + public long getDictionaryPageOffset() + { + throw hiddenColumnException(); + } + + @Override + public long getValueCount() + { + throw hiddenColumnException(); + } + + @Override + public long getTotalUncompressedSize() + { + throw hiddenColumnException(); + } + + @Override + public long getTotalSize() + { + throw hiddenColumnException(); + } + + @Override + public Statistics getStatistics() + { + throw hiddenColumnException(); + } + + public static boolean isHiddenColumn(ColumnChunkMetadata column) + { + return column instanceof HiddenColumnChunkMetadata; + } + + private ParquetCryptoException hiddenColumnException() + { + return new ParquetCryptoException(dataSourceId, "User does not have access to column %s, or the column key does not exist", path); + } + + @Override + public String toString() + { + return "HiddenColumnChunkMetadata{dataSourceId=" + dataSourceId + ", path=" + path + "}"; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ParquetMetadata.java b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ParquetMetadata.java index f0c640bd0dd8..6d39b434d6d9 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ParquetMetadata.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/metadata/ParquetMetadata.java @@ -19,14 +19,21 @@ import io.airlift.log.Logger; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipherUtils; +import io.trino.parquet.crypto.ColumnDecryptionContext; +import io.trino.parquet.crypto.FileDecryptionContext; +import io.trino.parquet.crypto.ModuleType; import io.trino.parquet.reader.MetadataReader; import org.apache.parquet.column.Encoding; import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnCryptoMetaData; import org.apache.parquet.format.ColumnMetaData; +import org.apache.parquet.format.EncryptionWithColumnKey; import org.apache.parquet.format.FileMetaData; import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.SchemaElement; +import org.apache.parquet.format.Util; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -35,6 +42,8 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; +import java.io.ByteArrayInputStream; +import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -54,6 +63,7 @@ import static 
io.trino.parquet.ParquetMetadataConverter.toColumnIndexReference; import static io.trino.parquet.ParquetMetadataConverter.toOffsetIndexReference; import static io.trino.parquet.ParquetValidationUtils.validateParquet; +import static io.trino.parquet.ParquetValidationUtils.validateParquetCrypto; import static java.util.Objects.requireNonNull; public class ParquetMetadata @@ -63,8 +73,9 @@ public class ParquetMetadata private final FileMetaData parquetMetadata; private final ParquetDataSourceId dataSourceId; private final FileMetadata fileMetadata; + private final Optional decryptionContext; - public ParquetMetadata(FileMetaData parquetMetadata, ParquetDataSourceId dataSourceId) + public ParquetMetadata(FileMetaData parquetMetadata, ParquetDataSourceId dataSourceId, Optional decryptionContext) throws ParquetCorruptionException { this.fileMetadata = new FileMetadata( @@ -73,6 +84,7 @@ public ParquetMetadata(FileMetaData parquetMetadata, ParquetDataSourceId dataSou parquetMetadata.getCreated_by()); this.parquetMetadata = parquetMetadata; this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); + this.decryptionContext = requireNonNull(decryptionContext, "decryptionContext is null"); } public FileMetadata getFileMetaData() @@ -80,6 +92,11 @@ public FileMetadata getFileMetaData() return fileMetadata; } + public Optional getDecryptionContext() + { + return decryptionContext; + } + @Override public String toString() { @@ -89,13 +106,13 @@ public String toString() } public List getBlocks() - throws ParquetCorruptionException + throws IOException { return getBlocks(0, Long.MAX_VALUE); } public List getBlocks(long splitStart, long splitLength) - throws ParquetCorruptionException + throws IOException { List schema = parquetMetadata.getSchema(); validateParquet(!schema.isEmpty(), dataSourceId, "Schema is empty"); @@ -114,21 +131,81 @@ public List getBlocks(long splitStart, long splitLength) List columns = rowGroup.getColumns(); validateParquet(!columns.isEmpty(), dataSourceId, "No columns in row group: %s", rowGroup); String filePath = columns.get(0).getFile_path(); - long rowGroupStart = getRowGroupStart(columns, messageType); - boolean splitContainsRowGroup = splitStart <= rowGroupStart && rowGroupStart < splitStart + splitLength; - if (!splitContainsRowGroup) { - continue; - } ImmutableList.Builder columnMetadataBuilder = ImmutableList.builderWithExpectedSize(columns.size()); + int columnOrdinal = -1; + boolean splitContainsRowGroup = true; for (ColumnChunk columnChunk : columns) { + columnOrdinal++; validateParquet( (filePath == null && columnChunk.getFile_path() == null) || (filePath != null && filePath.equals(columnChunk.getFile_path())), dataSourceId, "all column chunks of the same row group must be in the same file"); - ColumnChunkMetadata column = toColumnChunkMetadata(columnChunk, parquetMetadata.getCreated_by(), messageType); + ColumnCryptoMetaData cryptoMetaData = columnChunk.getCrypto_metadata(); + ColumnMetaData metaData; + ColumnPath columnPath; + if (cryptoMetaData == null) { + // Plaintext column + metaData = columnChunk.getMeta_data(); + columnPath = getPath(metaData.getPath_in_schema()); + decryptionContext.ifPresent(context -> context.initPlaintextColumn(columnPath)); + } + else { + validateParquetCrypto(decryptionContext.isPresent(), dataSourceId, "Column is encrypted, but no decryption context"); + if (cryptoMetaData.isSetENCRYPTION_WITH_FOOTER_KEY()) { + // Column encrypted with footer key + validateParquetCrypto(columnChunk.getMeta_data() != null, dataSourceId, 
"Column metadata is null"); + metaData = columnChunk.getMeta_data(); + columnPath = getPath(metaData.getPath_in_schema()); + decryptionContext.get().initializeColumnCryptoMetadata(columnPath, true, Optional.empty()); + } + else { + // Column encrypted with column key + EncryptionWithColumnKey columnKeyStruct = cryptoMetaData.getENCRYPTION_WITH_COLUMN_KEY(); + columnPath = getPath(columnKeyStruct.getPath_in_schema()); + Optional decryptedMetadata = decryptColumnMetadata(columnPath, rowGroup, columnKeyStruct.getKey_metadata(), columnChunk, decryptionContext.get(), columnOrdinal); + if (decryptedMetadata.isEmpty()) { + // User does not have access to the column key. Column is considered hidden. + columnMetadataBuilder.add(new HiddenColumnChunkMetadata(dataSourceId, columnPath)); + validateParquetCrypto(columnOrdinal != 0, dataSourceId, "First column of a row group is encrypted with an unknown column key. Cannot determine row group starting position."); + continue; + } + metaData = decryptedMetadata.get(); + } + } + + PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); + ColumnChunkMetadata column = ColumnChunkMetadata.get( + columnPath, + primitiveType, + CompressionCodecName.fromParquet(metaData.codec), + convertEncodingStats(metaData.encoding_stats), + readEncodings(metaData.encodings), + MetadataReader.readStats(Optional.ofNullable(parquetMetadata.getCreated_by()), Optional.ofNullable(metaData.statistics), primitiveType), + metaData.data_page_offset, + metaData.dictionary_page_offset, + metaData.num_values, + metaData.total_compressed_size, + metaData.total_uncompressed_size); + column.setColumnIndexReference(toColumnIndexReference(columnChunk)); + column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); + column.setBloomFilterOffset(metaData.bloom_filter_offset); + if (rowGroup.isSetOrdinal()) { + column.setRowGroupOrdinal(rowGroup.getOrdinal()); + } + column.setColumnOrdinal(columnOrdinal); columnMetadataBuilder.add(column); + + // Skip row group if it doesn't overlap the split. Only first column starting position matches row group start and can be used for the check. 
+ long rowGroupStart = getRowGroupStart(column); + splitContainsRowGroup = columnOrdinal != 0 || (splitStart <= rowGroupStart && rowGroupStart < splitStart + splitLength); + if (!splitContainsRowGroup) { + break; + } + } + if (!splitContainsRowGroup) { + continue; } blocks.add(new BlockMetadata(fileRowCountOffset, rowGroup.getNum_rows(), columnMetadataBuilder.build())); } @@ -143,38 +220,11 @@ public FileMetaData getParquetMetadata() return parquetMetadata; } - private static long getRowGroupStart(List columns, MessageType messageType) + private static long getRowGroupStart(ColumnChunkMetadata column) { // Note: Do not rely on org.apache.parquet.format.RowGroup.getFile_offset or org.apache.parquet.format.ColumnChunk.getFile_offset // because some versions of parquet-cpp-arrow (and potentially other writers) set it incorrectly - ColumnChunkMetadata columnChunkMetadata = toColumnChunkMetadata(columns.getFirst(), null, messageType); - return columnChunkMetadata.getStartingPos(); - } - - private static ColumnChunkMetadata toColumnChunkMetadata(ColumnChunk columnChunk, String createdBy, MessageType messageType) - { - ColumnMetaData metaData = columnChunk.meta_data; - String[] path = metaData.path_in_schema.stream() - .map(value -> value.toLowerCase(Locale.ENGLISH)) - .toArray(String[]::new); - ColumnPath columnPath = ColumnPath.get(path); - PrimitiveType primitiveType = messageType.getType(columnPath.toArray()).asPrimitiveType(); - ColumnChunkMetadata column = ColumnChunkMetadata.get( - columnPath, - primitiveType, - CompressionCodecName.fromParquet(metaData.codec), - convertEncodingStats(metaData.encoding_stats), - readEncodings(metaData.encodings), - MetadataReader.readStats(Optional.ofNullable(createdBy), Optional.ofNullable(metaData.statistics), primitiveType), - metaData.data_page_offset, - metaData.dictionary_page_offset, - metaData.num_values, - metaData.total_compressed_size, - metaData.total_uncompressed_size); - column.setColumnIndexReference(toColumnIndexReference(columnChunk)); - column.setOffsetIndexReference(toOffsetIndexReference(columnChunk)); - column.setBloomFilterOffset(metaData.bloom_filter_offset); - return column; + return column.getStartingPos(); } private static MessageType readParquetSchema(List schema) @@ -248,6 +298,31 @@ private static void readTypeSchema(Types.GroupBuilder builder, Iterator decryptColumnMetadata(ColumnPath columnPath, RowGroup rowGroup, byte[] columnKeyMetadata, ColumnChunk columnChunk, FileDecryptionContext decryptionContext, int columnOrdinal) + throws IOException + { + byte[] encryptedMetadataBuffer = columnChunk.getEncrypted_column_metadata(); + + // Decrypt the ColumnMetaData + Optional columnDecryptionContext = decryptionContext.initializeColumnCryptoMetadata(columnPath, false, Optional.ofNullable(columnKeyMetadata)); + if (columnDecryptionContext.isEmpty()) { + return Optional.empty(); + } + + ByteArrayInputStream tempInputStream = new ByteArrayInputStream(encryptedMetadataBuffer); + byte[] columnMetaDataAAD = AesCipherUtils.createModuleAAD(decryptionContext.getFileAad(), ModuleType.ColumnMetaData, rowGroup.ordinal, columnOrdinal, -1); + return Optional.of(Util.readColumnMetaData(tempInputStream, columnDecryptionContext.get().metadataDecryptor(), columnMetaDataAAD)); + } + + private static ColumnPath getPath(List pathInSchema) + { + requireNonNull(pathInSchema, "pathInSchema is null"); + String[] path = pathInSchema.stream() + .map(value -> value.toLowerCase(Locale.ENGLISH)) + .toArray(String[]::new); + return ColumnPath.get(path); + } + 
private static Set readEncodings(List encodings) { Set columnEncodings = new HashSet<>(); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java index 3230c7190a0a..1a34a62659b3 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; +import io.airlift.slice.Slices; import io.trino.parquet.BloomFilterStore; import io.trino.parquet.DictionaryPage; import io.trino.parquet.ParquetCorruptionException; @@ -25,6 +26,10 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetEncoding; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.AesCipherUtils; +import io.trino.parquet.crypto.ColumnDecryptionContext; +import io.trino.parquet.crypto.FileDecryptionContext; +import io.trino.parquet.crypto.ModuleType; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.ParquetMetadata; @@ -36,10 +41,10 @@ import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.DictionaryPageHeader; import org.apache.parquet.format.PageHeader; import org.apache.parquet.format.PageType; -import org.apache.parquet.format.Util; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import org.apache.parquet.internal.filter2.columnindex.ColumnIndexStore; import org.apache.parquet.io.ParquetDecodingException; @@ -69,6 +74,7 @@ import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static org.apache.parquet.format.Util.readPageHeader; public final class PredicateUtils { @@ -142,7 +148,8 @@ public static boolean predicateMatches( Optional columnIndexStore, Optional bloomFilterStore, DateTimeZone timeZone, - int domainCompactionThreshold) + int domainCompactionThreshold, + Optional decryptionContext) throws IOException { if (columnsMetadata.getRowCount() == 0) { @@ -177,7 +184,8 @@ public static boolean predicateMatches( dataSource, descriptorsByPath, ImmutableSet.copyOf(candidateColumns.get()), - columnIndexStore); + columnIndexStore, + decryptionContext); } public static List getFilteredRowGroups( @@ -198,8 +206,8 @@ public static List getFilteredRowGroups( for (int i = 0; i < parquetTupleDomains.size(); i++) { TupleDomain parquetTupleDomain = parquetTupleDomains.get(i); TupleDomainParquetPredicate parquetPredicate = parquetPredicates.get(i); - Optional columnIndex = getColumnIndexStore(dataSource, block, descriptorsByPath, parquetTupleDomain, options); - Optional bloomFilterStore = getBloomFilterStore(dataSource, block, parquetTupleDomain, options); + Optional columnIndex = getColumnIndexStore(dataSource, block, descriptorsByPath, parquetTupleDomain, options, parquetMetadata.getDecryptionContext()); + Optional bloomFilterStore = getBloomFilterStore(dataSource, block, parquetTupleDomain, options, parquetMetadata.getDecryptionContext()); PrunedBlockMetadata columnsMetadata = createPrunedColumnsMetadata(block, dataSource.getId(), descriptorsByPath); if (predicateMatches( 
parquetPredicate, @@ -210,7 +218,8 @@ public static List getFilteredRowGroups( columnIndex, bloomFilterStore, timeZone, - domainCompactionThreshold)) { + domainCompactionThreshold, + parquetMetadata.getDecryptionContext())) { rowGroupInfoBuilder.add(new RowGroupInfo(columnsMetadata, block.fileRowCountOffset(), columnIndex)); break; } @@ -250,7 +259,8 @@ private static boolean dictionaryPredicatesMatch( ParquetDataSource dataSource, Map, ColumnDescriptor> descriptorsByPath, Set candidateColumns, - Optional columnIndexStore) + Optional columnIndexStore, + Optional decryptionContext) throws IOException { for (ColumnDescriptor descriptor : descriptorsByPath.values()) { @@ -265,7 +275,7 @@ private static boolean dictionaryPredicatesMatch( if (!parquetPredicate.matches(new DictionaryDescriptor( descriptor, nullAllowed, - readDictionaryPage(dataSource, columnMetaData, columnIndexStore)))) { + readDictionaryPage(dataSource, columnMetaData, columnIndexStore, decryptionContext)))) { return false; } } @@ -276,7 +286,8 @@ private static boolean dictionaryPredicatesMatch( private static Optional readDictionaryPage( ParquetDataSource dataSource, ColumnChunkMetadata columnMetaData, - Optional columnIndexStore) + Optional columnIndexStore, + Optional decryptionContext) throws IOException { int dictionaryPageSize; @@ -300,7 +311,8 @@ private static Optional readDictionaryPage( } // Get the dictionary page header and the dictionary in single read Slice buffer = dataSource.readFully(columnMetaData.getStartingPos(), dictionaryPageSize); - return readPageHeaderWithData(buffer.getInput()).map(data -> decodeDictionaryPage(dataSource.getId(), data, columnMetaData)); + return readPageHeaderWithData(buffer.getInput(), columnMetaData, decryptionContext) + .map(data -> decodeDictionaryPage(dataSource.getId(), data, columnMetaData, decryptionContext)); } private static Optional getDictionaryPageSize(ColumnIndexStore columnIndexStore, ColumnChunkMetadata columnMetaData) @@ -317,11 +329,20 @@ private static Optional getDictionaryPageSize(ColumnIndexStore columnIn return Optional.empty(); } - private static Optional readPageHeaderWithData(SliceInput inputStream) + private static Optional readPageHeaderWithData(SliceInput inputStream, ColumnChunkMetadata columnMetaData, Optional decryptionContext) { - PageHeader pageHeader; + Optional columnContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(columnMetaData.getPath())); + BlockCipher.Decryptor decryptor = null; + byte[] headerAad = null; + if (columnContext.isPresent()) { + decryptor = columnContext.map(ColumnDecryptionContext::metadataDecryptor).orElse(null); + byte[] fileAad = decryptionContext.get().getFileAad(); + headerAad = AesCipherUtils.createModuleAAD(fileAad, ModuleType.DictionaryPageHeader, columnMetaData.getRowGroupOrdinal(), columnMetaData.getColumnOrdinal(), -1); + } + + final PageHeader pageHeader; try { - pageHeader = Util.readPageHeader(inputStream); + pageHeader = readPageHeader(inputStream, decryptor, headerAad); } catch (IOException e) { throw new UncheckedIOException(e); @@ -339,7 +360,7 @@ private static Optional readPageHeaderWithData(SliceInput in inputStream.readSlice(pageHeader.getCompressed_page_size()))); } - private static DictionaryPage decodeDictionaryPage(ParquetDataSourceId dataSourceId, PageHeaderWithData pageHeaderWithData, ColumnChunkMetadata chunkMetaData) + private static DictionaryPage decodeDictionaryPage(ParquetDataSourceId dataSourceId, PageHeaderWithData pageHeaderWithData, ColumnChunkMetadata 
chunkMetaData, Optional decryptionContext) { PageHeader pageHeader = pageHeaderWithData.pageHeader(); DictionaryPageHeader dicHeader = pageHeader.getDictionary_page_header(); @@ -347,8 +368,20 @@ private static DictionaryPage decodeDictionaryPage(ParquetDataSourceId dataSourc int dictionarySize = dicHeader.getNum_values(); Slice compressedData = pageHeaderWithData.compressedData(); + Slice maybeDecrypted = compressedData; + Optional columnContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(chunkMetaData.getPath())); + if (columnContext.isPresent()) { + byte[] aad = AesCipherUtils.createModuleAAD( + columnContext.get().fileAad(), + ModuleType.DictionaryPage, + chunkMetaData.getRowGroupOrdinal(), + chunkMetaData.getColumnOrdinal(), + -1); + byte[] plain = columnContext.get().dataDecryptor().decrypt(compressedData.getBytes(), aad); + maybeDecrypted = Slices.wrappedBuffer(plain); + } try { - return new DictionaryPage(decompress(dataSourceId, chunkMetaData.getCodec().getParquetCompressionCodec(), compressedData, pageHeader.getUncompressed_page_size()), dictionarySize, encoding); + return new DictionaryPage(decompress(dataSourceId, chunkMetaData.getCodec().getParquetCompressionCodec(), maybeDecrypted, pageHeader.getUncompressed_page_size()), dictionarySize, encoding); } catch (IOException e) { throw new ParquetDecodingException("Could not decode the dictionary for " + chunkMetaData.getPath(), e); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index 369ce467e131..8808d592995a 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -14,56 +14,68 @@ package io.trino.parquet.reader; import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; import io.airlift.slice.Slices; import io.airlift.units.DataSize; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.crypto.AesCipherUtils; +import io.trino.parquet.crypto.AesGcmEncryptor; +import io.trino.parquet.crypto.FileDecryptionContext; +import io.trino.parquet.crypto.FileDecryptionProperties; import io.trino.parquet.metadata.FileMetadata; import io.trino.parquet.metadata.ParquetMetadata; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; +import org.apache.parquet.format.BlockCipher.Decryptor; +import org.apache.parquet.format.FileCryptoMetaData; import org.apache.parquet.format.FileMetaData; import org.apache.parquet.format.Statistics; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.PrimitiveType; import java.io.IOException; -import java.io.InputStream; import java.util.Arrays; import java.util.Optional; import static io.trino.parquet.ParquetMetadataConverter.fromParquetStatistics; import static io.trino.parquet.ParquetValidationUtils.validateParquet; +import static io.trino.parquet.ParquetValidationUtils.validateParquetCrypto; +import static io.trino.parquet.crypto.AesCipherUtils.GCM_TAG_LENGTH; +import static io.trino.parquet.crypto.AesCipherUtils.NONCE_LENGTH; import static java.lang.Boolean.FALSE; import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static 
java.lang.System.arraycopy; +import static org.apache.parquet.format.Util.readFileCryptoMetaData; import static org.apache.parquet.format.Util.readFileMetaData; public final class MetadataReader { private static final Slice MAGIC = Slices.utf8Slice("PAR1"); + private static final Slice EMAGIC = Slices.utf8Slice("PARE"); private static final int POST_SCRIPT_SIZE = Integer.BYTES + MAGIC.length(); // Typical 1GB files produced by Trino were found to have footer size between 30-40KB private static final int EXPECTED_FOOTER_SIZE = 48 * 1024; private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional fileDecryptionProperties) throws IOException { - return readFooter(dataSource, Optional.empty(), Optional.empty()); + return readFooter(dataSource, Optional.empty(), Optional.empty(), fileDecryptionProperties); } - public static ParquetMetadata readFooter(ParquetDataSource dataSource, DataSize maxFooterReadSize) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, DataSize maxFooterReadSize, Optional fileDecryptionProperties) throws IOException { - return readFooter(dataSource, Optional.of(maxFooterReadSize), Optional.empty()); + return readFooter(dataSource, Optional.of(maxFooterReadSize), Optional.empty(), fileDecryptionProperties); } - public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional maxFooterReadSize, Optional parquetWriteValidation) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional maxFooterReadSize, Optional parquetWriteValidation, Optional fileDecryptionProperties) throws IOException { // Parquet File Layout: @@ -82,8 +94,9 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< Slice buffer = dataSource.readTail(toIntExact(expectedReadSize)); Slice magic = buffer.slice(buffer.length() - MAGIC.length(), MAGIC.length()); - validateParquet(MAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s got: %s", MAGIC.toStringUtf8(), magic.toStringUtf8()); + validateParquet(MAGIC.equals(magic) || EMAGIC.equals(magic), dataSource.getId(), "Expected magic number: %s or %s got: %s", MAGIC.toStringUtf8(), EMAGIC.toStringUtf8(), magic.toStringUtf8()); + boolean encryptedFooterMode = EMAGIC.equals(magic); int metadataLength = buffer.getInt(buffer.length() - POST_SCRIPT_SIZE); long metadataIndex = estimatedFileSize - POST_SCRIPT_SIZE - metadataLength; validateParquet( @@ -104,10 +117,32 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional< // initial read was not large enough, so just read again with the correct size buffer = dataSource.readTail(completeFooterSize); } - InputStream metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); + SliceInput metadataStream = buffer.slice(buffer.length() - completeFooterSize, metadataLength).getInput(); - FileMetaData fileMetaData = readFileMetaData(metadataStream); - ParquetMetadata parquetMetadata = new ParquetMetadata(fileMetaData, dataSource.getId()); + Optional decryptionContext = Optional.empty(); + Decryptor footerDecryptor = null; + byte[] aad = null; + + if (encryptedFooterMode) { + validateParquetCrypto(fileDecryptionProperties.isPresent(), dataSource.getId(), "fileDecryptionProperties cannot be null when encryptedFooterMode is true"); + FileCryptoMetaData fileCryptoMetaData = readFileCryptoMetaData(metadataStream); + 
validateParquetCrypto(fileCryptoMetaData != null, dataSource.getId(), "FileCryptoMetaData cannot be null when encryptedFooterMode is true"); + decryptionContext = Optional.of(new FileDecryptionContext(dataSource.getId(), fileDecryptionProperties.get(), fileCryptoMetaData.getEncryption_algorithm(), Optional.ofNullable(fileCryptoMetaData.getKey_metadata()))); + footerDecryptor = decryptionContext.get().getFooterDecryptor(); + aad = AesCipherUtils.createFooterAAD(decryptionContext.get().getFileAad()); + } + + FileMetaData fileMetaData = readFileMetaData(metadataStream, footerDecryptor, aad); + if (!encryptedFooterMode && fileDecryptionProperties.isPresent() && fileMetaData.isSetEncryption_algorithm()) { + // footer is not encrypted, but some columns might be encrypted + decryptionContext = Optional.of(new FileDecryptionContext(dataSource.getId(), fileDecryptionProperties.get(), fileMetaData.getEncryption_algorithm(), Optional.ofNullable(fileMetaData.getFooter_signing_key_metadata()))); + if (fileDecryptionProperties.get().isCheckFooterIntegrity()) { + // verify footer integrity + verifyFooterIntegrity(dataSource, metadataStream, decryptionContext.get(), metadataLength); + } + } + + ParquetMetadata parquetMetadata = new ParquetMetadata(fileMetaData, dataSource.getId(), decryptionContext); validateFileMetadata(dataSource.getId(), parquetMetadata.getFileMetaData(), parquetWriteValidation); return parquetMetadata; } @@ -219,4 +254,26 @@ private static void validateFileMetadata(ParquetDataSourceId dataSourceId, FileM Optional.ofNullable(fileMetaData.getKeyValueMetaData().get("writer.time.zone"))); writeValidation.validateColumns(dataSourceId, fileMetaData.getSchema()); } + + private static void verifyFooterIntegrity(ParquetDataSource dataSource, SliceInput metadataStream, FileDecryptionContext decryptionContext, int metadataLength) + { + byte[] nonce = new byte[NONCE_LENGTH]; + metadataStream.read(nonce); + + byte[] gcmTag = new byte[GCM_TAG_LENGTH]; + metadataStream.read(gcmTag); + + // read only the serialized footer without the tags + int footerSignatureLength = NONCE_LENGTH + GCM_TAG_LENGTH; + byte[] footer = new byte[metadataLength - footerSignatureLength]; + metadataStream.setPosition(0); + metadataStream.read(footer, 0, footer.length); + byte[] signedFooterAAD = AesCipherUtils.createFooterAAD(decryptionContext.getFileAad()); + + AesGcmEncryptor footerSigner = decryptionContext.getFooterEncryptor(); + byte[] encryptedFooterBytes = footerSigner.encrypt(false, footer, nonce, signedFooterAAD); + byte[] calculatedTag = new byte[GCM_TAG_LENGTH]; + arraycopy(encryptedFooterBytes, encryptedFooterBytes.length - GCM_TAG_LENGTH, calculatedTag, 0, GCM_TAG_LENGTH); + validateParquetCrypto(Arrays.equals(gcmTag, calculatedTag), dataSource.getId(), "Signature mismatch in plaintext footer"); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java index d8ec35c52fbe..0d9c85d8727e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java @@ -16,17 +16,24 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; +import io.airlift.slice.Slice; import io.trino.parquet.DataPage; import io.trino.parquet.DataPageV1; import io.trino.parquet.DataPageV2; import io.trino.parquet.DictionaryPage; import 
io.trino.parquet.Page; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipherUtils; +import io.trino.parquet.crypto.ColumnDecryptionContext; +import io.trino.parquet.crypto.FileDecryptionContext; +import io.trino.parquet.crypto.ModuleType; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.CompressionCodec; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.OffsetIndex; import java.io.IOException; @@ -35,6 +42,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.parquet.ParquetCompressionUtils.decompress; import static io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages; import static java.util.Objects.requireNonNull; @@ -46,9 +54,14 @@ public final class PageReader private final boolean hasOnlyDictionaryEncodedPages; private final boolean hasNoNulls; private final PeekingIterator compressedPages; + private final Optional blockDecryptor; private boolean dictionaryAlreadyRead; private int dataPageReadCount; + @Nullable + private byte[] dataPageAad; + @Nullable + private byte[] dictionaryPageAad; public static PageReader createPageReader( ParquetDataSourceId dataSourceId, @@ -56,7 +69,8 @@ public static PageReader createPageReader( ColumnChunkMetadata metadata, ColumnDescriptor columnDescriptor, @Nullable OffsetIndex offsetIndex, - Optional fileCreatedBy) + Optional fileCreatedBy, + Optional decryptionContext) { // Parquet schema may specify a column definition as OPTIONAL even though there are no nulls in the actual data. 
// Row-group column statistics can be used to identify such cases and switch to faster non-nullable read @@ -64,20 +78,25 @@ public static PageReader createPageReader( Statistics columnStatistics = metadata.getStatistics(); boolean hasNoNulls = columnStatistics != null && columnStatistics.getNumNulls() == 0; boolean hasOnlyDictionaryEncodedPages = isOnlyDictionaryEncodingPages(metadata); + Optional columnDecryptionContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(ColumnPath.get(columnDescriptor.getPath()))); ParquetColumnChunkIterator compressedPages = new ParquetColumnChunkIterator( dataSourceId, fileCreatedBy, columnDescriptor, metadata, columnChunk, - offsetIndex); + offsetIndex, + columnDecryptionContext); return new PageReader( dataSourceId, metadata.getCodec().getParquetCompressionCodec(), compressedPages, hasOnlyDictionaryEncodedPages, - hasNoNulls); + hasNoNulls, + columnDecryptionContext, + metadata.getRowGroupOrdinal(), + metadata.getColumnOrdinal()); } @VisibleForTesting @@ -86,13 +105,21 @@ public PageReader( CompressionCodec codec, Iterator compressedPages, boolean hasOnlyDictionaryEncodedPages, - boolean hasNoNulls) + boolean hasNoNulls, + Optional decryptionContext, + int rowGroupOrdinal, + int columnOrdinal) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.codec = codec; this.compressedPages = Iterators.peekingIterator(compressedPages); this.hasOnlyDictionaryEncodedPages = hasOnlyDictionaryEncodedPages; this.hasNoNulls = hasNoNulls; + this.blockDecryptor = decryptionContext.map(ColumnDecryptionContext::dataDecryptor); + if (blockDecryptor.isPresent()) { + dataPageAad = AesCipherUtils.createModuleAAD(decryptionContext.get().fileAad(), ModuleType.DataPage, rowGroupOrdinal, columnOrdinal, 0); + dictionaryPageAad = AesCipherUtils.createModuleAAD(decryptionContext.get().fileAad(), ModuleType.DictionaryPage, rowGroupOrdinal, columnOrdinal, -1); + } } public boolean hasNoNulls() @@ -114,18 +141,20 @@ public DataPage readPage() checkState(compressedPage instanceof DataPage, "Found page %s instead of a DataPage", compressedPage); dataPageReadCount++; try { + if (blockDecryptor.isPresent()) { + AesCipherUtils.quickUpdatePageAAD(dataPageAad, ((DataPage) compressedPage).getPageIndex()); + } + Slice slice = decryptSliceIfNeeded(compressedPage.getSlice(), dataPageAad); if (compressedPage instanceof DataPageV1 dataPageV1) { - if (!arePagesCompressed()) { - return dataPageV1; - } return new DataPageV1( - decompress(dataSourceId, codec, dataPageV1.getSlice(), dataPageV1.getUncompressedSize()), + !arePagesCompressed() ? 
slice : decompress(dataSourceId, codec, slice, dataPageV1.getUncompressedSize()), dataPageV1.getValueCount(), dataPageV1.getUncompressedSize(), dataPageV1.getFirstRowIndex(), dataPageV1.getRepetitionLevelEncoding(), dataPageV1.getDefinitionLevelEncoding(), - dataPageV1.getValueEncoding()); + dataPageV1.getValueEncoding(), + dataPageV1.getPageIndex()); } DataPageV2 dataPageV2 = (DataPageV2) compressedPage; if (!dataPageV2.isCompressed()) { @@ -141,11 +170,12 @@ public DataPage readPage() dataPageV2.getRepetitionLevels(), dataPageV2.getDefinitionLevels(), dataPageV2.getDataEncoding(), - decompress(dataSourceId, codec, dataPageV2.getSlice(), uncompressedSize), + decompress(dataSourceId, codec, slice, uncompressedSize), dataPageV2.getUncompressedSize(), dataPageV2.getFirstRowIndex(), dataPageV2.getStatistics(), - false); + false, + dataPageV2.getPageIndex()); } catch (IOException e) { throw new RuntimeException("Could not decompress page", e); @@ -162,8 +192,9 @@ public DictionaryPage readDictionaryPage() } try { DictionaryPage compressedDictionaryPage = (DictionaryPage) compressedPages.next(); + Slice slice = decryptSliceIfNeeded(compressedDictionaryPage.getSlice(), dictionaryPageAad); return new DictionaryPage( - decompress(dataSourceId, codec, compressedDictionaryPage.getSlice(), compressedDictionaryPage.getUncompressedSize()), + decompress(dataSourceId, codec, slice, compressedDictionaryPage.getUncompressedSize()), compressedDictionaryPage.getDictionarySize(), compressedDictionaryPage.getEncoding()); } @@ -199,4 +230,14 @@ private void verifyDictionaryPageRead() { checkArgument(dictionaryAlreadyRead, "Dictionary has to be read first"); } + + private Slice decryptSliceIfNeeded(Slice slice, byte[] aad) + throws IOException + { + if (blockDecryptor.isEmpty()) { + return slice; + } + byte[] plainText = blockDecryptor.get().decrypt(slice.getBytes(), aad); + return wrappedBuffer(plainText); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java index 235c1b2d3d76..6a67adcd44a6 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunkIterator.java @@ -19,10 +19,14 @@ import io.trino.parquet.Page; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.crypto.AesCipherUtils; +import io.trino.parquet.crypto.ColumnDecryptionContext; +import io.trino.parquet.crypto.ModuleType; import io.trino.parquet.metadata.ColumnChunkMetadata; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; +import org.apache.parquet.format.BlockCipher.Decryptor; import org.apache.parquet.format.DataPageHeader; import org.apache.parquet.format.DataPageHeaderV2; import org.apache.parquet.format.DictionaryPageHeader; @@ -48,17 +52,22 @@ public final class ParquetColumnChunkIterator private final ColumnChunkMetadata metadata; private final ChunkedInputStream input; private final OffsetIndex offsetIndex; + private final Optional decryptionContext; private long valueCount; private int dataPageCount; + private byte[] dataPageHeaderAad; + private boolean dictionaryWasRead; + public ParquetColumnChunkIterator( ParquetDataSourceId dataSourceId, Optional fileCreatedBy, ColumnDescriptor descriptor, ColumnChunkMetadata metadata, 
ChunkedInputStream input, - @Nullable OffsetIndex offsetIndex) + @Nullable OffsetIndex offsetIndex, + Optional decryptionContext) { this.dataSourceId = requireNonNull(dataSourceId, "dataSourceId is null"); this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -66,6 +75,7 @@ public ParquetColumnChunkIterator( this.metadata = requireNonNull(metadata, "metadata is null"); this.input = requireNonNull(input, "input is null"); this.offsetIndex = offsetIndex; + this.decryptionContext = requireNonNull(decryptionContext, "decryptionContext is null"); } @Override @@ -80,7 +90,16 @@ public Page next() checkState(hasNext(), "No more data left to read in column (%s), metadata (%s), valueCount %s, dataPageCount %s", descriptor, metadata, valueCount, dataPageCount); try { - PageHeader pageHeader = readPageHeader(); + byte[] pageHeaderAAD = null; + if (decryptionContext.isPresent()) { + if (!dictionaryWasRead && metadata.hasDictionaryPage()) { + pageHeaderAAD = getDictionaryPageHeaderAAD(); + } + else { + pageHeaderAAD = getPageHeaderAAD(); + } + } + PageHeader pageHeader = readPageHeader(decryptionContext.map(ColumnDecryptionContext::metadataDecryptor).orElse(null), pageHeaderAAD); int uncompressedPageSize = pageHeader.getUncompressed_page_size(); int compressedPageSize = pageHeader.getCompressed_page_size(); Page result = null; @@ -90,13 +109,14 @@ public Page next() throw new ParquetCorruptionException(dataSourceId, "Column (%s) has a dictionary page after the first position in column chunk", descriptor); } result = readDictionaryPage(pageHeader, pageHeader.getUncompressed_page_size(), pageHeader.getCompressed_page_size()); + dictionaryWasRead = true; break; case DATA_PAGE: - result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV1(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; case DATA_PAGE_V2: - result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex)); + result = readDataPageV2(pageHeader, uncompressedPageSize, compressedPageSize, getFirstRowIndex(dataPageCount, offsetIndex), dataPageCount); ++dataPageCount; break; default: @@ -110,10 +130,39 @@ public Page next() } } - private PageHeader readPageHeader() + private byte[] getPageHeaderAAD() + { + checkState(decryptionContext.isPresent()); + + if (dataPageHeaderAad != null) { + AesCipherUtils.quickUpdatePageAAD(dataPageHeaderAad, dataPageCount); + return dataPageHeaderAad; + } + + dataPageHeaderAad = AesCipherUtils.createModuleAAD( + decryptionContext.get().fileAad(), + ModuleType.DataPageHeader, + metadata.getRowGroupOrdinal(), + metadata.getColumnOrdinal(), + dataPageCount); + return dataPageHeaderAad; + } + + private byte[] getDictionaryPageHeaderAAD() + { + checkState(decryptionContext.isPresent()); + return AesCipherUtils.createModuleAAD( + decryptionContext.get().fileAad(), + ModuleType.DictionaryPageHeader, + metadata.getRowGroupOrdinal(), + metadata.getColumnOrdinal(), + -1); + } + + private PageHeader readPageHeader(Decryptor headerBlockDecryptor, byte[] pageHeaderAAD) throws IOException { - return Util.readPageHeader(input); + return Util.readPageHeader(input, headerBlockDecryptor, pageHeaderAAD); } private boolean hasMorePages(long valuesCountReadSoFar, int dataPageCountReadSoFar) @@ -139,7 +188,8 @@ private DataPageV1 readDataPageV1( PageHeader pageHeader, int 
uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeader dataHeaderV1 = pageHeader.getData_page_header(); @@ -151,14 +201,16 @@ private DataPageV1 readDataPageV1( firstRowIndex, getParquetEncoding(Encoding.valueOf(dataHeaderV1.getRepetition_level_encoding().name())), getParquetEncoding(Encoding.valueOf(dataHeaderV1.getDefinition_level_encoding().name())), - getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name()))); + getParquetEncoding(Encoding.valueOf(dataHeaderV1.getEncoding().name())), + pageIndex); } private DataPageV2 readDataPageV2( PageHeader pageHeader, int uncompressedPageSize, int compressedPageSize, - OptionalLong firstRowIndex) + OptionalLong firstRowIndex, + int pageIndex) throws IOException { DataPageHeaderV2 dataHeaderV2 = pageHeader.getData_page_header_v2(); @@ -178,7 +230,8 @@ private DataPageV2 readDataPageV2( fileCreatedBy, Optional.ofNullable(dataHeaderV2.getStatistics()), descriptor.getPrimitiveType()), - dataHeaderV2.isIs_compressed()); + dataHeaderV2.isIs_compressed(), + pageIndex); } private static OptionalLong getFirstRowIndex(int pageIndex, OffsetIndex offsetIndex) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 90c2b8b57a75..9fa3e7268de8 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -33,6 +33,7 @@ import io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.PrimitiveField; import io.trino.parquet.VariantField; +import io.trino.parquet.crypto.FileDecryptionContext; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.PrunedBlockMetadata; import io.trino.parquet.predicate.TupleDomainParquetPredicate; @@ -150,6 +151,7 @@ public class ParquetReader private int currentPageId; private long columnIndexRowsFiltered = -1; + private final Optional decryptionContext; public ParquetReader( Optional fileCreatedBy, @@ -162,7 +164,8 @@ public ParquetReader( ParquetReaderOptions options, Function exceptionTransform, Optional parquetPredicate, - Optional writeValidation) + Optional writeValidation, + Optional decryptionContext) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -180,6 +183,7 @@ public ParquetReader( this.maxBatchSize = options.getMaxReadBlockRowCount(); this.columnReaders = new HashMap<>(); this.maxBytesPerCell = new HashMap<>(); + this.decryptionContext = requireNonNull(decryptionContext, "decryptionContext is null"); this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); validateWrite( @@ -668,7 +672,7 @@ private ColumnChunk readPrimitive(PrimitiveField field) } ChunkedInputStream columnChunkInputStream = chunkReaders.get(new ChunkKey(fieldId, currentRowGroup)); columnReader.setPageReader( - createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy), + createPageReader(dataSource.getId(), columnChunkInputStream, metadata, columnDescriptor, offsetIndex, fileCreatedBy, decryptionContext), Optional.ofNullable(rowRanges)); } ColumnChunk columnChunk = columnReader.readPrimitive(); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/TrinoColumnIndexStore.java 
b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/TrinoColumnIndexStore.java index fa9b7ae142d5..d08d28b6f857 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/TrinoColumnIndexStore.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/TrinoColumnIndexStore.java @@ -19,6 +19,9 @@ import io.trino.parquet.DiskRange; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.ColumnDecryptionContext; +import io.trino.parquet.crypto.FileDecryptionContext; +import io.trino.parquet.crypto.ModuleType; import io.trino.parquet.metadata.BlockMetadata; import io.trino.parquet.metadata.ColumnChunkMetadata; import io.trino.parquet.metadata.IndexReference; @@ -26,6 +29,7 @@ import io.trino.spi.predicate.TupleDomain; import jakarta.annotation.Nullable; import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.format.BlockCipher; import org.apache.parquet.format.Util; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.internal.column.columnindex.ColumnIndex; @@ -47,6 +51,7 @@ import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.parquet.ParquetMetadataConverter.fromParquetColumnIndex; import static io.trino.parquet.ParquetMetadataConverter.fromParquetOffsetIndex; +import static io.trino.parquet.crypto.AesCipherUtils.createModuleAAD; import static java.util.Objects.requireNonNull; /** @@ -59,6 +64,7 @@ public class TrinoColumnIndexStore private final ParquetDataSource dataSource; private final List columnIndexReferences; private final List offsetIndexReferences; + private final Optional decryptionContext; @Nullable private Map columnIndexStore; @@ -75,9 +81,11 @@ public TrinoColumnIndexStore( ParquetDataSource dataSource, BlockMetadata block, Set columnsRead, - Set columnsFiltered) + Set columnsFiltered, + Optional decryptionContext) { this.dataSource = requireNonNull(dataSource, "dataSource is null"); + this.decryptionContext = requireNonNull(decryptionContext, "decryptionContext is null"); requireNonNull(block, "block is null"); requireNonNull(columnsRead, "columnsRead is null"); requireNonNull(columnsFiltered, "columnsFiltered is null"); @@ -90,13 +98,17 @@ public TrinoColumnIndexStore( columnIndexBuilder.add(new ColumnIndexMetadata( column.getColumnIndexReference(), path, - column.getPrimitiveType())); + column.getPrimitiveType(), + column.getRowGroupOrdinal(), + column.getColumnOrdinal())); } if (column.getOffsetIndexReference() != null && columnsRead.contains(path)) { offsetIndexBuilder.add(new ColumnIndexMetadata( column.getOffsetIndexReference(), path, - column.getPrimitiveType())); + column.getPrimitiveType(), + column.getRowGroupOrdinal(), + column.getColumnOrdinal())); } } this.columnIndexReferences = columnIndexBuilder.build(); @@ -109,14 +121,22 @@ public ColumnIndex getColumnIndex(ColumnPath column) if (columnIndexStore == null) { columnIndexStore = loadIndexes(dataSource, columnIndexReferences, (inputStream, columnMetadata) -> { try { - return fromParquetColumnIndex(columnMetadata.getPrimitiveType(), Util.readColumnIndex(inputStream)); + Optional columnContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(columnMetadata.getPath())); + byte[] aad = columnContext + .map(context -> createModuleAAD(context.fileAad(), ModuleType.ColumnIndex, columnMetadata.getRowGroupOrdinal(), columnMetadata.getColumnOrdinal(), -1)) + .orElse(null); + BlockCipher.Decryptor 
thriftDecryptor = columnContext + .map(ColumnDecryptionContext::metadataDecryptor) + .orElse(null); + return fromParquetColumnIndex( + columnMetadata.getPrimitiveType(), + Util.readColumnIndex(inputStream, thriftDecryptor, aad)); } catch (IOException e) { throw new RuntimeException(e); } }); } - return columnIndexStore.get(column); } @@ -126,14 +146,20 @@ public OffsetIndex getOffsetIndex(ColumnPath column) if (offsetIndexStore == null) { offsetIndexStore = loadIndexes(dataSource, offsetIndexReferences, (inputStream, columnMetadata) -> { try { - return fromParquetOffsetIndex(Util.readOffsetIndex(inputStream)); + Optional columnContext = decryptionContext.flatMap(context -> context.getColumnDecryptionContext(columnMetadata.getPath())); + byte[] aad = columnContext + .map(context -> createModuleAAD(context.fileAad(), ModuleType.OffsetIndex, columnMetadata.getRowGroupOrdinal(), columnMetadata.getColumnOrdinal(), -1)) + .orElse(null); + BlockCipher.Decryptor thriftDecryptor = columnContext + .map(ColumnDecryptionContext::metadataDecryptor) + .orElse(null); + return fromParquetOffsetIndex(Util.readOffsetIndex(inputStream, thriftDecryptor, aad)); } catch (IOException e) { throw new RuntimeException(e); } }); } - return offsetIndexStore.get(column); } @@ -142,7 +168,8 @@ public static Optional getColumnIndexStore( BlockMetadata blockMetadata, Map, ColumnDescriptor> descriptorsByPath, TupleDomain parquetTupleDomain, - ParquetReaderOptions options) + ParquetReaderOptions options, + Optional decryptionContext) { if (!options.isUseColumnIndex() || parquetTupleDomain.isAll() || parquetTupleDomain.isNone()) { return Optional.empty(); @@ -171,7 +198,7 @@ public static Optional getColumnIndexStore( .map(column -> ColumnPath.get(column.getPath())) .collect(toImmutableSet()); - return Optional.of(new TrinoColumnIndexStore(dataSource, blockMetadata, columnsReadPaths, columnsFilteredPaths)); + return Optional.of(new TrinoColumnIndexStore(dataSource, blockMetadata, columnsReadPaths, columnsFilteredPaths, decryptionContext)); } private static Map loadIndexes( @@ -202,13 +229,17 @@ private static class ColumnIndexMetadata private final DiskRange diskRange; private final ColumnPath path; private final PrimitiveType primitiveType; + private final int rowGroupOrdinal; + private final int columnOrdinal; - private ColumnIndexMetadata(IndexReference indexReference, ColumnPath path, PrimitiveType primitiveType) + private ColumnIndexMetadata(IndexReference indexReference, ColumnPath path, PrimitiveType primitiveType, int rowGroupOrdinal, int columnOrdinal) { requireNonNull(indexReference, "indexReference is null"); this.diskRange = new DiskRange(indexReference.getOffset(), indexReference.getLength()); this.path = requireNonNull(path, "path is null"); this.primitiveType = requireNonNull(primitiveType, "primitiveType is null"); + this.rowGroupOrdinal = rowGroupOrdinal; + this.columnOrdinal = columnOrdinal; } private DiskRange getDiskRange() @@ -225,5 +256,15 @@ private PrimitiveType getPrimitiveType() { return primitiveType; } + + private int getRowGroupOrdinal() + { + return rowGroupOrdinal; + } + + private int getColumnOrdinal() + { + return columnOrdinal; + } } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 0251d897b905..5366c401d0d2 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ 
b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -237,7 +237,7 @@ public void validate(ParquetDataSource input) checkState(validationBuilder.isPresent(), "validation is not enabled"); ParquetWriteValidation writeValidation = validationBuilder.get().build(); try { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.empty(), Optional.of(writeValidation)); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.empty(), Optional.of(writeValidation), Optional.empty()); try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { for (SourcePage page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { // fully load the page @@ -294,7 +294,8 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada return new RuntimeException(exception); }, Optional.empty(), - Optional.of(writeValidation)); + Optional.of(writeValidation), + Optional.empty()); } private void recordValidation(Consumer task) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java index b4cbdcec0b2b..09b7dd79ab9b 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java @@ -226,7 +226,7 @@ public void setup() testData.getColumnNames(), testData.getPages()), ParquetReaderOptions.defaultOptions()); - parquetMetadata = MetadataReader.readFooter(dataSource); + parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); columnNames = columns.stream() .map(TpchColumn::getColumnName) .collect(toImmutableList()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index 4372aa91cda1..ebb180374394 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -176,6 +176,7 @@ public static ParquetReader createParquetReader( return new RuntimeException(exception); }, Optional.of(parquetPredicate), + Optional.empty(), Optional.empty()); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/crypto/TestParquetEncryption.java b/lib/trino-parquet/src/test/java/io/trino/parquet/crypto/TestParquetEncryption.java new file mode 100644 index 000000000000..5997155e66d0 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/crypto/TestParquetEncryption.java @@ -0,0 +1,1038 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.crypto; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.PrimitiveField; +import io.trino.parquet.metadata.ColumnChunkMetadata; +import io.trino.parquet.metadata.ParquetMetadata; +import io.trino.parquet.predicate.TupleDomainParquetPredicate; +import io.trino.parquet.reader.FileParquetDataSource; +import io.trino.parquet.reader.MetadataReader; +import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.RowGroupInfo; +import io.trino.spi.block.Block; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.connector.SourcePage; +import io.trino.spi.predicate.TupleDomain; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.crypto.ColumnEncryptionProperties; +import org.apache.parquet.crypto.FileEncryptionProperties; +import org.apache.parquet.crypto.ParquetCipher; +import org.apache.parquet.example.data.Group; +import org.apache.parquet.example.data.simple.SimpleGroupFactory; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.example.ExampleParquetWriter; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.MessageTypeParser; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Types; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.Supplier; + +import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.parquet.predicate.PredicateUtils.getFilteredRowGroups; +import static io.trino.spi.predicate.Domain.singleValue; +import static io.trino.spi.type.IntegerType.INTEGER; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.nio.file.Files.createTempFile; +import static org.apache.parquet.hadoop.ParquetFileWriter.Mode.OVERWRITE; +import static org.apache.parquet.hadoop.metadata.CompressionCodecName.SNAPPY; +import static org.apache.parquet.hadoop.metadata.CompressionCodecName.UNCOMPRESSED; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.joda.time.DateTimeZone.UTC; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +/** + * Unit‑tests exercising Trino’s PME path. 
+ */
+// ExampleParquetWriter is not thread-safe
+@TestInstance(PER_CLASS)
+@Execution(SAME_THREAD)
+public final class TestParquetEncryption
+{
+    private static final ColumnPath AGE_PATH = ColumnPath.fromDotString("age");
+    private static final ColumnPath ID_PATH = ColumnPath.fromDotString("id");
+
+    private static final byte[] KEY_AGE = "colKeyIs16ByteA?".getBytes(UTF_8);
+    private static final byte[] KEY_ID = "colKeyIs16ByteB?".getBytes(UTF_8);
+    private static final byte[] KEY_FOOT = "footKeyIs16Byte?".getBytes(UTF_8);
+
+    // one-column schema
+    private static final MessageType AGE_SCHEMA =
+            MessageTypeParser.parseMessageType("message doc { required int32 age; }");
+
+    // two-column schema
+    private static final MessageType TWO_COL_SCHEMA = MessageTypeParser.parseMessageType(
+            "message doc { required int32 age; required int32 id; }");
+
+    /**
+     * Column encryption only – footer left in plaintext.
+     */
+    @ParameterizedTest(name = "checkFooterIntegrity={0}, compressed={1}")
+    @CsvSource({
+            "false,false",
+            "false,true",
+            "true,false",
+            "true,true"
+    })
+    void columnOnlyFooterPlaintext(boolean checkFooterIntegrity, boolean compressed)
+            throws IOException
+    {
+        File file = createTempFile("pme-col-only", ".parquet").toFile();
+        file.deleteOnExit();
+
+        writeSingleColumnFile(file, compressed, /*encryptFooter*/ false);
+
+        List<Integer> values = readSingleColumnFile(
+                file,
+                new TestingKeyRetriever(checkFooterIntegrity ? Optional.of(KEY_FOOT) : Optional.empty(), Optional.of(KEY_AGE), Optional.empty()),
+                checkFooterIntegrity,
+                Optional.empty());
+
+        verifySequence(values, 100);
+    }
+
+    /**
+     * Column + footer encryption (same column as above).
+     */
+    @ParameterizedTest(name = "compressed={0}")
+    @CsvSource({"false", "true"})
+    void columnAndFooter(boolean compressed)
+            throws IOException
+    {
+        File file = createTempFile("pme-col-foot", ".parquet").toFile();
+        file.deleteOnExit();
+
+        writeSingleColumnFile(file, compressed, /*encryptFooter*/ true);
+
+        List<Integer> values = readSingleColumnFile(
+                file,
+                new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty()),
+                /*checkFooterIntegrity*/ true,
+                Optional.empty());
+
+        verifySequence(values, 100);
+    }
+
+    /**
+     * Single column encrypted with the footer key.
+ */ + @ParameterizedTest(name = "encryptFooter={0}") + @CsvSource({"false", "true"}) + void columnEncryptedWithFooterKey(boolean encryptFooter) + throws IOException + { + File file = createTempFile("pme‑footerKey‑col", ".parquet").toFile(); + file.deleteOnExit(); + + // write: column encrypted with footer key, footer encrypted as usual + writeSingleColumnFile(file, /*compressed*/ false, encryptFooter, /*columnEncryptedWithFooterKey*/ true, Optional.empty(), () -> range(0, 100), OptionalInt.empty(), OptionalInt.empty()); + + // read: supply ONLY the footer key (no column key necessary) + List values = readSingleColumnFile( + file, + new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.empty(), Optional.empty()), + /*checkFooterIntegrity*/ true, + Optional.empty()); + + verifySequence(values, 100); // data round‑trip OK + } + + @ParameterizedTest(name = "supplyPrefix={0}") + @CsvSource({"true", "false"}) + void aadPrefixRoundTrip(boolean supplyPrefix) + throws IOException + { + File file = createTempFile("pme‑aad", ".parquet").toFile(); + file.deleteOnExit(); + + byte[] prefix = "fileAAD‑prefix".getBytes(UTF_8); + writeSingleColumnFile(file, /*compressed*/ true, /*encryptFooter*/ true, /*columnEncryptedWithFooterKey*/ false, Optional.of(prefix), () -> range(0, 100), OptionalInt.empty(), OptionalInt.empty()); + + // build FileDecryptionProperties with / without the required prefix + FileDecryptionProperties.Builder properties = FileDecryptionProperties.builder() + .withKeyRetriever(new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty())); + + if (supplyPrefix) { + properties.withAadPrefix(prefix); + // correct prefix => read succeeds + List values = readSingleColumnFile( + file, + new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty()), + true, // footer integrity + Optional.of(prefix)); // pass prefix + + verifySequence(values, 100); + } + else { + // missing / wrong prefix => should fail while reading footer + assertThatThrownBy(() -> readSingleColumnFile( + file, + new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty()), + true, + Optional.empty())) // NO prefix + .isInstanceOfAny(ParquetCryptoException.class); + } + } + + @ParameterizedTest(name = "rows={0}, rowGroupSize={1}, pageSize={2}, encryptFooter={3}, compressed={4}") + @CsvSource({ + // rows, rowGroupSizeBytes, pageSizeBytes, encryptFooter, compressed + "2000, 1024, 128, true, false", + "1500, 2048, 256, false, true", + }) + void multipleRowGroups(int rows, int rowGroupSize, int pageSize, boolean encryptFooter, boolean compressed) + throws IOException + { + File file = createTempFile("pme-multirg", ".parquet").toFile(); + file.deleteOnExit(); + + writeSingleColumnFile( + file, + compressed, + encryptFooter, + /* columnEncryptedWithFooterKey */ false, + Optional.empty(), + () -> range(0, rows), + OptionalInt.of(rowGroupSize), + OptionalInt.of(pageSize)); + + // sanity: we really created >1 row group + try (ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build())) { + FileDecryptionProperties properties = FileDecryptionProperties.builder() + .withKeyRetriever(new TestingKeyRetriever( + /* footer */ encryptFooter ? 
Optional.of(KEY_FOOT) : Optional.empty(), + /* column */ Optional.of(KEY_AGE), + Optional.empty())) + .withCheckFooterIntegrity(encryptFooter) + .build(); + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(properties)); + assertThat(metadata.getBlocks().size()).isGreaterThan(1); + } + + List values = readSingleColumnFile( + file, + new TestingKeyRetriever(encryptFooter ? Optional.of(KEY_FOOT) : Optional.empty(), Optional.of(KEY_AGE), Optional.empty()), + encryptFooter, + Optional.empty()); + + verifySequence(values, rows); + } + + /** + * Two columns; footer encrypted; one column plaintext, other encrypted. + */ + @Test + void mixedEncryptedAndPlaintextColumns() + throws IOException + { + File file = createTempFile("pme‑mixed‑cols", ".parquet").toFile(); + file.deleteOnExit(); + + // age column encrypted, id column plaintext + writeTwoColumnFile(file, /*encryptAge*/ true, /*encryptId*/ false); + + // Provide keys for both footer & encrypted age column + Map> data = readTwoColumnFile(file, new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty())); + + verifySequence(data.get("age"), 100); + verifyReverseSequence(data.get("id"), 100); + } + + /** + * Two columns; different keys; refuse access to one column. + */ + @Test + void oneColumnAccessibleOtherInaccessible() + throws IOException + { + File file = createTempFile("pme‑locked‑col", ".parquet").toFile(); + file.deleteOnExit(); + + // both columns encrypted with different keys + writeTwoColumnFile(file, /*encryptAge*/ true, /*encryptId*/ true); + + // reader has footer key + KEY_AGE only, not KEY_ID + TestingKeyRetriever retriever = new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty()); + + List values = readSingleColumnFile(file, retriever, true, Optional.empty()); + + verifySequence(values, 100); + } + + /** + * Single column encrypted, written with dictionary encoding. 
+ */ + @Test + void dictionaryEncodedEncryptedColumn() + throws IOException + { + File file = createTempFile("pme‑dict", ".parquet").toFile(); + file.deleteOnExit(); + + // create a very small domain so writer chooses dictionary + writeSingleColumnFile( + file, + /*compressed*/ false, + /*encryptFooter*/ true, + /*columnEncryptedWithFooterKey*/ false, + Optional.empty(), + () -> List.of(1, 2, 3, 1, 2, 3), + OptionalInt.empty(), + OptionalInt.empty()); // limited distincts + + List values = readSingleColumnFile( + file, + new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty()), + true, + Optional.empty()); + + assertThat(new HashSet<>(values)).containsExactlyInAnyOrder(1, 2, 3); + + // ── extra assertion: verify a dictionary really exists ── + try (ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build())) { + FileDecryptionProperties properties = FileDecryptionProperties.builder() + .withKeyRetriever(new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.empty())) + .build(); + + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(properties)); + + // first (and only) row‑group → column‑chunk for "age" + ColumnChunkMetadata chunk = + metadata.getBlocks().getFirst().columns().stream() + .filter(column -> column.getPath().equals(AGE_PATH)) + .findFirst() + .orElseThrow(); + + // 1) dictionary page must be present + assertThat(chunk.getDictionaryPageOffset()).isGreaterThan(0); + + // 2) PLAIN_DICTIONARY (or RLE_DICTIONARY in v2) must be in the encoding list + assertThat(chunk.getEncodings()).anyMatch(Encoding::usesDictionary); + + // 3) (optional) Trino helper: all data pages are dictionary‑encoded + assertThat(io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages(chunk)).isTrue(); + } + } + + @Test + void readFailsWithoutFooterKey() + throws IOException + { + File file = createTempFile("pme-no-footer-key", ".parquet").toFile(); + file.deleteOnExit(); + + // footer + column encrypted + writeSingleColumnFile(file, /*compressed*/ false, /*encryptFooter*/ true); + + assertThatThrownBy(() -> readSingleColumnFile( + file, + new TestingKeyRetriever(/* footer */ Optional.empty(), + /* column */ Optional.of(KEY_AGE), + Optional.empty()), + /*checkFooterIntegrity*/ true, + Optional.empty())) + .isInstanceOf(ParquetCryptoException.class); + } + + /** + * Footer key is present, but the column key is missing. + * Footer decrypts; column decryption must fail. 
+ */
+    @Test
+    void readFailsWithoutColumnKey()
+            throws IOException
+    {
+        File file = createTempFile("pme-no-column-key", ".parquet").toFile();
+        file.deleteOnExit();
+
+        // footer + column encrypted
+        writeSingleColumnFile(file, /*compressed*/ false, /*encryptFooter*/ true);
+
+        assertThatThrownBy(() -> readSingleColumnFile(
+                file,
+                new TestingKeyRetriever(/* footer */ Optional.of(KEY_FOOT),
+                        /* column */ Optional.empty(),
+                        Optional.empty()),
+                /*checkFooterIntegrity*/ true,
+                Optional.empty()))
+                .isInstanceOf(ParquetCryptoException.class);
+    }
+
+    @Test
+    void readFailsWithInvalidFooterKey()
+            throws IOException
+    {
+        File file = createTempFile("pme-bad-footer", ".parquet").toFile();
+        file.deleteOnExit();
+
+        writeSingleColumnFile(file, /*compressed*/ false, /*encryptFooter*/ true);
+
+        byte[] wrongFooter = "thisIsTheBadKey!".getBytes(UTF_8); // 16 bytes != real key
+
+        assertThatThrownBy(() -> readSingleColumnFile(
+                file,
+                new TestingKeyRetriever(Optional.of(wrongFooter), Optional.of(KEY_AGE), Optional.empty()),
+                /*checkFooterIntegrity*/ true,
+                Optional.empty()))
+                .isInstanceOf(ParquetCryptoException.class);
+    }
+
+    /**
+     * Footer key is correct but the column key is wrong.
+     * Footer decrypts, column decryption must still fail.
+     */
+    @Test
+    void readFailsWithInvalidColumnKey()
+            throws IOException
+    {
+        File file = createTempFile("pme-bad-column", ".parquet").toFile();
+        file.deleteOnExit();
+
+        writeSingleColumnFile(file, /*compressed*/ false, /*encryptFooter*/ true);
+
+        byte[] wrongColumn = "badColumnKey1234".getBytes(UTF_8); // 16 bytes, but not the real key
+
+        assertThatThrownBy(() -> readSingleColumnFile(
+                file,
+                new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(wrongColumn), Optional.empty()),
+                /*checkFooterIntegrity*/ true,
+                Optional.empty()))
+                .isInstanceOf(ParquetCryptoException.class);
+    }
+
+    @Test
+    void encryptedDictionaryPruningTwoColumns()
+            throws IOException
+    {
+        File file = createTempFile("pme-dict-2cols", ".parquet").toFile();
+        file.deleteOnExit();
+
+        int missingAge = 7;
+        int missingId = 3;
+        writeTwoColumnEncryptedDictionaryFile(file, missingAge, missingId);
+
+        // Open with footer key + AGE key only (no ID key)
+        try (ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build())) {
+            FileDecryptionProperties props = FileDecryptionProperties.builder()
+                    .withKeyRetriever(new TestingKeyRetriever(
+                            Optional.of(KEY_FOOT), // footer
+                            Optional.of(KEY_AGE), // age column key
+                            Optional.empty())) // NO id key
+                    .withCheckFooterIntegrity(true)
+                    .build();
+
+            ParquetMetadata metadata = MetadataReader.readFooter(
+                    source, Optional.empty(), Optional.empty(), Optional.of(props));
+
+            ColumnDescriptor age = new ColumnDescriptor(
+                    new String[] {"age"},
+                    Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("age"),
+                    0, 0);
+
+            ColumnDescriptor id = new ColumnDescriptor(
+                    new String[] {"id"},
+                    Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("id"),
+                    0, 0);
+
+            // Predicate on accessible column (age = missingAge) -> dictionary-based pruning to 0
+            TupleDomain<ColumnDescriptor> domainAge = TupleDomain.withColumnDomains(ImmutableMap.of(age, singleValue(INTEGER, (long) missingAge)));
+
+            TupleDomainParquetPredicate predicateAge = new TupleDomainParquetPredicate(
+                    domainAge, ImmutableList.of(age), UTC);
+
+            List<RowGroupInfo> groupsAge = getFilteredRowGroups(
+                    0,
+                    source.getEstimatedSize(),
+                    source,
+                    metadata,
+                    List.of(domainAge),
+                    List.of(predicateAge),
+                    ImmutableMap.of(ImmutableList.of("age"), age),
+                    UTC,
+                    200,
ParquetReaderOptions.builder().build()); + + // No row-groups should pass after dictionary pruning + assertThat(groupsAge).isEmpty(); + + // ——— Predicate on inaccessible column (id = missingId) → should fail (no column key) ——— + TupleDomain domainId = TupleDomain.withColumnDomains(ImmutableMap.of(id, singleValue(INTEGER, (long) missingId))); + + TupleDomainParquetPredicate predicateId = new TupleDomainParquetPredicate(domainId, ImmutableList.of(id), UTC); + + assertThatThrownBy(() -> getFilteredRowGroups( + 0, + source.getEstimatedSize(), + source, + metadata, + List.of(domainId), + List.of(predicateId), + ImmutableMap.of(ImmutableList.of("id"), id), + UTC, + 200, + ParquetReaderOptions.builder().build())) + // keep this broad; different layers may surface different messages + .hasMessageMatching("(?s).*access.*column.*id.*|.*decrypt.*id.*|.*key.*id.*"); + } + } + + @Test + void encryptedBloomFilterPruningTwoColumns() + throws IOException + { + File file = createTempFile("pme-bloom-2cols", ".parquet").toFile(); + file.deleteOnExit(); + + int missingAge = 7; + int missingId = 3; + writeTwoColumnEncryptedBloomFile(file, missingAge, missingId); + assertBloomFiltersPresent(file); + + try (ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build())) { + // Footer + AGE key only (no ID key) + FileDecryptionProperties props = FileDecryptionProperties.builder() + .withKeyRetriever(new TestingKeyRetriever( + Optional.of(KEY_FOOT), // footer + Optional.of(KEY_AGE), // age column key + Optional.empty())) // NO id key + .withCheckFooterIntegrity(true) + .build(); + + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(props)); + + ColumnDescriptor age = new ColumnDescriptor( + new String[] {"age"}, + Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("age"), + 0, 0); + + ColumnDescriptor id = new ColumnDescriptor( + new String[] {"id"}, + Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("id"), + 0, 0); + + // --- Predicate on accessible column (age == missingAge) → Bloom filter should prune to 0 + TupleDomain domainAge = TupleDomain.withColumnDomains( + ImmutableMap.of(age, singleValue(INTEGER, (long) missingAge))); + TupleDomainParquetPredicate predicateAge = new TupleDomainParquetPredicate(domainAge, ImmutableList.of(age), UTC); + + List groupsAge = getFilteredRowGroups( + 0, + source.getEstimatedSize(), + source, + metadata, + List.of(domainAge), + List.of(predicateAge), + ImmutableMap.of(ImmutableList.of("age"), age), + UTC, + 200, + ParquetReaderOptions.builder().build()); + + assertThat(groupsAge).isEmpty(); // pruned by encrypted Bloom filter + + // Sanity: present value should not prune to 0 + TupleDomain domainAgePresent = TupleDomain.withColumnDomains( + ImmutableMap.of(age, singleValue(INTEGER, 5L))); + TupleDomainParquetPredicate predicateAgePresent = new TupleDomainParquetPredicate(domainAgePresent, ImmutableList.of(age), UTC); + List groupsAgePresent = getFilteredRowGroups( + 0, + source.getEstimatedSize(), + source, + metadata, + List.of(domainAgePresent), + List.of(predicateAgePresent), + ImmutableMap.of(ImmutableList.of("age"), age), + UTC, + 200, + ParquetReaderOptions.builder().build()); + assertThat(groupsAgePresent).isNotEmpty(); + + // --- Predicate on inaccessible column (id == missingId) → should fail (no column key) + TupleDomain domainId = TupleDomain.withColumnDomains( + ImmutableMap.of(id, singleValue(INTEGER, (long) missingId))); + 
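+            // Evaluating this predicate requires decrypting the dictionary page of the "id" column,
+            // but the retriever above supplies no key for "id", so row-group filtering is expected
+            // to surface a key-access/decryption error (asserted below).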
TupleDomainParquetPredicate predicateId = new TupleDomainParquetPredicate(domainId, ImmutableList.of(id), UTC); + + assertThatThrownBy(() -> getFilteredRowGroups( + 0, + source.getEstimatedSize(), + source, + metadata, + List.of(domainId), + List.of(predicateId), + ImmutableMap.of(ImmutableList.of("id"), id), + UTC, + 200, + ParquetReaderOptions.builder().build())) + .hasMessageMatching("(?s).*access.*column.*id.*|.*decrypt.*id.*|.*key.*id.*"); + } + } + + /** + * Single‑column writer with knobs. + */ + private static void writeSingleColumnFile(File target, boolean compressed, boolean encryptFooter) + throws IOException + { + writeSingleColumnFile(target, compressed, encryptFooter, false, Optional.empty(), () -> range(0, 100), OptionalInt.empty(), OptionalInt.empty()); + } + + /** + * Overload that takes explicit values supplier (for dictionary case). + */ + private static void writeSingleColumnFile( + File target, + boolean compressed, + boolean encryptFooter, + boolean columnEncryptedWithFooterKey, + Optional aadPrefix, + Supplier> valuesSupplier, + OptionalInt rowGroupSizeBytes, + OptionalInt pageSizeBytes) + throws IOException + { + ColumnEncryptionProperties.Builder ageEncryptionBuilder = ColumnEncryptionProperties.builder(AGE_PATH); + if (!columnEncryptedWithFooterKey) { + ageEncryptionBuilder.withKey(KEY_AGE); + } + ColumnEncryptionProperties ageEncryption = ageEncryptionBuilder.build(); + + FileEncryptionProperties.Builder fileEncryption = FileEncryptionProperties.builder(KEY_FOOT) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(ImmutableMap.of(AGE_PATH, ageEncryption)); + + aadPrefix.ifPresent(prefix -> { + fileEncryption.withAADPrefix(prefix); + fileEncryption.withoutAADPrefixStorage(); + }); + + if (!encryptFooter) { + fileEncryption.withPlaintextFooter(); + } + + ExampleParquetWriter.Builder builder = ExampleParquetWriter + .builder(new Path(target.getAbsolutePath())) + .withType(AGE_SCHEMA) + .withConf(new Configuration()) + .withEncryption(fileEncryption.build()) + .withWriteMode(OVERWRITE) + .withCompressionCodec(compressed ? SNAPPY : UNCOMPRESSED); + + rowGroupSizeBytes.ifPresent(builder::withRowGroupSize); + // tiny page => even 100 ints give us ≥ 2 data pages + builder.withPageSize(pageSizeBytes.orElse(64)); + + try (ParquetWriter writer = builder.build()) { + SimpleGroupFactory factory = new SimpleGroupFactory(AGE_SCHEMA); + for (int value : valuesSupplier.get()) { + writer.write(factory.newGroup().append("age", value)); + } + } + } + + /** + * Two‑column writer. 
+ */ + private static void writeTwoColumnFile( + File target, + boolean encryptAge, + boolean encryptId) + throws IOException + { + ImmutableMap.Builder columnMap = ImmutableMap.builder(); + if (encryptAge) { + columnMap.put(AGE_PATH, ColumnEncryptionProperties.builder(AGE_PATH).withKey(KEY_AGE).build()); + } + if (encryptId) { + columnMap.put(ID_PATH, ColumnEncryptionProperties.builder(ID_PATH).withKey(KEY_ID).build()); + } + + FileEncryptionProperties fileEncryption = FileEncryptionProperties.builder(KEY_FOOT) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(columnMap.buildOrThrow()) + .build(); + + ExampleParquetWriter.Builder builder = ExampleParquetWriter.builder(new Path(target.getAbsolutePath())) + .withType(TWO_COL_SCHEMA) + .withConf(new Configuration()) + .withEncryption(fileEncryption) + .withWriteMode(OVERWRITE) + // tiny page => even 100 ints give us ≥ 2 data pages + .withPageSize(64); + + try (ParquetWriter writer = builder.build()) { + SimpleGroupFactory factory = new SimpleGroupFactory(TWO_COL_SCHEMA); + for (int i = 0; i < 100; i++) { + writer.write(factory.newGroup().append("id", 100 - i).append("age", i)); + } + } + } + + private static void writeTwoColumnEncryptedDictionaryFile(File target, int missingAge, int missingId) + throws IOException + { + FileEncryptionProperties fileEncryption = FileEncryptionProperties.builder(KEY_FOOT) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(ImmutableMap.of( + AGE_PATH, ColumnEncryptionProperties.builder(AGE_PATH).withKey(KEY_AGE).build(), + ID_PATH, ColumnEncryptionProperties.builder(ID_PATH).withKey(KEY_ID).build())) + .build(); + + ExampleParquetWriter.Builder builder = ExampleParquetWriter.builder(new Path(target.getAbsolutePath())) + .withType(TWO_COL_SCHEMA) + .withConf(new Configuration()) + .withEncryption(fileEncryption) + .withWriteMode(OVERWRITE) + .withPageSize(1024); // small pages → strong chance of dictionary encoding + + writeSampleData(builder, missingAge, missingId); + assertDictionariesPresent(target); + } + + private static void assertDictionariesPresent(File file) + throws IOException + { + try (ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build())) { + FileDecryptionProperties properties = FileDecryptionProperties.builder() + // for metadata inspection we provide BOTH column keys + footer key + .withKeyRetriever(new TestingKeyRetriever( + Optional.of(KEY_FOOT), + Optional.of(KEY_AGE), + Optional.of(KEY_ID))) + .withCheckFooterIntegrity(true) + .build(); + + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(properties)); + + ColumnChunkMetadata age = metadata.getBlocks().getFirst().columns().stream() + .filter(context -> context.getPath().equals(AGE_PATH)) + .findFirst().orElseThrow(); + + ColumnChunkMetadata id = metadata.getBlocks().getFirst().columns().stream() + .filter(context -> context.getPath().equals(ID_PATH)) + .findFirst().orElseThrow(); + + // dictionary page present + assertThat(age.getDictionaryPageOffset()).isGreaterThan(0); + assertThat(id.getDictionaryPageOffset()).isGreaterThan(0); + + // column encodings include a dictionary encoding + assertThat(age.getEncodings()).anyMatch(org.apache.parquet.column.Encoding::usesDictionary); + assertThat(id.getEncodings()).anyMatch(org.apache.parquet.column.Encoding::usesDictionary); + + // (optional) every data page is dictionary-encoded + 
assertThat(io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages(age)).isTrue(); + assertThat(io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages(id)).isTrue(); + } + } + + private static void writeTwoColumnEncryptedBloomFile(File target, int missingAge, int missingId) + throws IOException + { + // Encryption: both columns with different keys; encrypted footer + FileEncryptionProperties fileEncryption = FileEncryptionProperties.builder(KEY_FOOT) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(ImmutableMap.of( + AGE_PATH, ColumnEncryptionProperties.builder(AGE_PATH).withKey(KEY_AGE).build(), + ID_PATH, ColumnEncryptionProperties.builder(ID_PATH).withKey(KEY_ID).build())) + .build(); + + // Enable bloom filters and select columns + ExampleParquetWriter.Builder builder = ExampleParquetWriter.builder(new Path(target.getAbsolutePath())) + .withType(TWO_COL_SCHEMA) + .withEncryption(fileEncryption) + .withWriteMode(OVERWRITE) + .withPageSize(1024) + // Bloom filters won't be used with dictionary encoding + .withDictionaryEncoding(false) + .withBloomFilterEnabled(true); + + writeSampleData(builder, missingAge, missingId); + assertBloomFiltersPresent(target); + } + + private static void writeSampleData(ExampleParquetWriter.Builder builder, int missingAge, int missingId) + throws IOException + { + try (ParquetWriter writer = builder.build()) { + SimpleGroupFactory factory = new SimpleGroupFactory(TWO_COL_SCHEMA); + for (int i = 0; i < 5000; i++) { + int id = i % 11; + // ensure 'missingId' not present + if (id == missingId) { + id = (id + 1) % 11; + } + int age = i % 11; + // ensure 'missingAge' not present + if (age == missingAge) { + age = (age + 1) % 11; + } + writer.write(factory.newGroup().append("id", id).append("age", age)); + } + } + } + + private static void assertBloomFiltersPresent(File file) + throws IOException + { + try (ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build())) { + // Provide ALL keys for metadata inspection + FileDecryptionProperties properties = FileDecryptionProperties.builder() + .withKeyRetriever(new TestingKeyRetriever( + Optional.of(KEY_FOOT), + Optional.of(KEY_AGE), + Optional.of(KEY_ID))) + .withCheckFooterIntegrity(true) + .build(); + + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(properties)); + + ColumnChunkMetadata age = metadata.getBlocks().getFirst().columns().stream() + .filter(context -> context.getPath().equals(AGE_PATH)) + .findFirst().orElseThrow(); + + ColumnChunkMetadata id = metadata.getBlocks().getFirst().columns().stream() + .filter(context -> context.getPath().equals(ID_PATH)) + .findFirst().orElseThrow(); + + // Bloom filter offsets must be present for both columns + assertThat(age.getBloomFilterOffset()).isGreaterThan(0L); + assertThat(id.getBloomFilterOffset()).isGreaterThan(0L); + } + } + + /** + * Reads the single‑column file and returns the “age” values. 
+ */ + private static List readSingleColumnFile( + File file, + DecryptionKeyRetriever keyRetriever, + boolean checkFooterIntegrity, + Optional aadPrefix) + throws IOException + { + ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build()); + + FileDecryptionProperties.Builder properties = FileDecryptionProperties + .builder() + .withKeyRetriever(keyRetriever) + .withCheckFooterIntegrity(checkFooterIntegrity); + aadPrefix.ifPresent(properties::withAadPrefix); + + ParquetMetadata metadata = MetadataReader.readFooter( + source, Optional.empty(), Optional.empty(), Optional.of(properties.build())); + + ColumnDescriptor descriptor = new ColumnDescriptor( + new String[] {"age"}, + Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("age"), + 0, 0); + + Map, ColumnDescriptor> byPath = ImmutableMap.of( + ImmutableList.of("age"), descriptor); + + TupleDomain domain = TupleDomain.all(); + TupleDomainParquetPredicate predicate = new TupleDomainParquetPredicate( + domain, ImmutableList.of(descriptor), UTC); + + List groups = getFilteredRowGroups( + 0, source.getEstimatedSize(), + source, metadata, + List.of(domain), List.of(predicate), + byPath, UTC, 200, ParquetReaderOptions.builder().build()); + + PrimitiveField field = new PrimitiveField(INTEGER, true, descriptor, 0); + io.trino.parquet.Column column = new io.trino.parquet.Column("age", field); + + try (ParquetReader reader = new ParquetReader( + Optional.ofNullable(metadata.getFileMetaData().getCreatedBy()), + List.of(column), + false, + groups, + source, + UTC, + newSimpleAggregatedMemoryContext(), + ParquetReaderOptions.builder().build(), + RuntimeException::new, + Optional.of(predicate), + Optional.empty(), + metadata.getDecryptionContext())) { + List out = new ArrayList<>(); + SourcePage page; + while ((page = reader.nextPage()) != null) { + Block block = page.getBlock(0); + IntArrayBlock ints = (IntArrayBlock) block.getUnderlyingValueBlock(); + for (int i = 0; i < ints.getPositionCount(); i++) { + out.add(block.isNull(i) ? null : ints.getInt(i)); + } + } + return out; + } + finally { + source.close(); + } + } + + /** + * Reads both columns and returns a map “age” → values, “id → values. 
+ */ + private static Map> readTwoColumnFile( + File file, DecryptionKeyRetriever retriever) + throws IOException + { + ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build()); + + FileDecryptionProperties properties = FileDecryptionProperties + .builder().withKeyRetriever(retriever).withCheckFooterIntegrity(true).build(); + + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(properties)); + + ColumnDescriptor ageDescriptor = new ColumnDescriptor( + new String[] {"age"}, + Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("age"), 0, 0); + + ColumnDescriptor idDescriptor = new ColumnDescriptor( + new String[] {"id"}, + Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("id"), 0, 0); + + Map, ColumnDescriptor> byPath = ImmutableMap.of( + ImmutableList.of("age"), ageDescriptor, + ImmutableList.of("id"), idDescriptor); + + TupleDomainParquetPredicate predicate = new TupleDomainParquetPredicate( + TupleDomain.all(), ImmutableList.of(ageDescriptor, idDescriptor), UTC); + + List groups = getFilteredRowGroups( + 0, source.getEstimatedSize(), source, metadata, + List.of(TupleDomain.all()), List.of(predicate), + byPath, UTC, 200, ParquetReaderOptions.builder().build()); + + PrimitiveField ageField = new PrimitiveField(INTEGER, true, ageDescriptor, 0); + PrimitiveField idField = new PrimitiveField(INTEGER, true, idDescriptor, 1); + + List columns = List.of( + new io.trino.parquet.Column("age", ageField), + new io.trino.parquet.Column("id", idField)); + + try (ParquetReader reader = new ParquetReader( + Optional.ofNullable(metadata.getFileMetaData().getCreatedBy()), + columns, + false, groups, source, + UTC, + newSimpleAggregatedMemoryContext(), + ParquetReaderOptions.builder().build(), + RuntimeException::new, + Optional.of(predicate), + Optional.empty(), + metadata.getDecryptionContext())) { + List ages = new ArrayList<>(); + List ids = new ArrayList<>(); + + SourcePage page; + while ((page = reader.nextPage()) != null) { + for (int column = 0; column < 2; column++) { + Block block = page.getBlock(column); + IntArrayBlock ints = (IntArrayBlock) block.getUnderlyingValueBlock(); + List values = (column == 0) ? ages : ids; + for (int i = 0; i < ints.getPositionCount(); i++) { + values.add(block.isNull(i) ? 
null : ints.getInt(i)); + } + } + } + return Map.of("age", ages, "id", ids); + } + finally { + source.close(); + } + } + + private static void verifySequence(List actual, int n) + { + assertThat(actual).hasSize(n); + for (int i = 0; i < n; i++) { + assertThat(actual.get(i)).isEqualTo(i); + } + } + + private static void verifyReverseSequence(List actual, int n) + { + assertThat(actual).hasSize(n); + for (int i = 0; i < n; i++) { + assertThat(actual.get(i)).isEqualTo(100 - i); + } + } + + private static List range(int fromInclusive, int toExclusive) + { + List out = new ArrayList<>(toExclusive - fromInclusive); + for (int i = fromInclusive; i < toExclusive; i++) { + out.add(i); + } + return out; + } + + private static final class TestingKeyRetriever + implements DecryptionKeyRetriever + { + private final Optional footerKey; + private final Optional ageColumnKey; + private final Optional idColumnKey; + + public TestingKeyRetriever(Optional footerKey, Optional ageColumnKey, Optional idColumnKey) + { + this.footerKey = footerKey; + this.ageColumnKey = ageColumnKey; + this.idColumnKey = idColumnKey; + } + + @Override + public Optional getColumnKey(ColumnPath path, Optional keyMetadata) + { + if ("age".equals(path.toDotString())) { + return ageColumnKey; + } + if ("id".equals(path.toDotString())) { + return idColumnKey; + } + return Optional.empty(); + } + + @Override + public Optional getFooterKey(Optional keyMetadata) + { + return footerKey; + } + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java index 84e0abc7631d..7c9d414a41a3 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderBenchmark.java @@ -105,7 +105,7 @@ public int read() throws IOException { ColumnReader columnReader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, dataPages.iterator(), false, false, Optional.empty(), -1, -1); columnReader.setPageReader(pageReader, Optional.empty()); int rowsRead = 0; while (rowsRead < dataPositions) { @@ -133,7 +133,8 @@ private DataPage createDataPage(ValuesWriter writer, int valuesCount) OptionalLong.empty(), RLE, RLE, - getParquetEncoding(writer.getEncoding())); + getParquetEncoding(writer.getEncoding()), + 0); } protected static void run(Class clazz) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java index 6a3fccb1e281..e7af59113053 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java @@ -285,8 +285,8 @@ public String toString() private static ColumnReaderInput[] getColumnReaderInputs(ColumnReaderProvider columnReaderProvider) { Object[][] definitionLevelsProviders = Arrays.stream(DefinitionLevelsProvider.ofDefinitionLevel( - columnReaderProvider.getField().getDefinitionLevel(), - columnReaderProvider.getField().isRequired())) + 
columnReaderProvider.getField().getDefinitionLevel(), + columnReaderProvider.getField().isRequired())) .collect(toDataProvider()); PrimitiveField field = columnReaderProvider.getField(); @@ -564,7 +564,10 @@ else if (dictionaryEncoding == DictionaryEncoding.MIXED) { UNCOMPRESSED, inputPages.iterator(), dictionaryEncoding == DictionaryEncoding.ALL || (dictionaryEncoding == DictionaryEncoding.MIXED && testingPages.size() == 1), - false); + false, + Optional.empty(), + -1, + -1); } private static List createDataPages(List testingPages, ValuesWriter encoder, int maxDef, boolean required) @@ -599,7 +602,8 @@ private static DataPage createDataPage(TestingPage testingPage, ValuesWriter enc valueCount * 4, OptionalLong.of(testingPage.pageRowRange().start()), null, - false); + false, + 0); encoder.reset(); return dataPage; } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java index 8a51ff5f995c..4ff78c4f1a74 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java @@ -660,7 +660,8 @@ protected static DataPage createDataPage( OptionalLong.empty(), getParquetEncoding(repetitionWriter.getEncoding()), getParquetEncoding(definitionWriter.getEncoding()), - encoding); + encoding, + 0); } return new DataPageV2( valueCount, @@ -673,7 +674,8 @@ protected static DataPage createDataPage( definitionBytes.length + repetitionBytes.length + valueBytes.length, OptionalLong.empty(), null, - false); + false, + 0); } protected static PageReader getPageReaderMock(List dataPages, @Nullable DictionaryPage dictionaryPage) @@ -699,7 +701,7 @@ protected static PageReader getPageReaderMock(List dataPages, @Nullabl return ((DataPageV2) page).getDataEncoding(); }) .allMatch(encoding -> encoding == PLAIN_DICTIONARY || encoding == RLE_DICTIONARY), - hasNoNulls); + hasNoNulls, Optional.empty(), -1, -1); } private DataPage createDataPage(DataPageVersion version, ParquetEncoding encoding, ValuesWriter writer, int valueCount) @@ -713,7 +715,7 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin { Slice slice = Slices.wrappedBuffer(writer.getBytes().toByteArray()); if (version == V1) { - return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding); + return new DataPageV1(slice, valueCount, slice.length(), firstRowIndex, RLE, BIT_PACKED, encoding, 0); } return new DataPageV2( valueCount, @@ -726,7 +728,8 @@ private DataPage createDataPage(DataPageVersion version, ParquetEncoding encodin slice.length(), firstRowIndex, null, - false); + false, + 0); } private static ValuesWriter getLevelsWriter(int maxLevel, int valueCount) diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java index fc0b69b4dd60..3fbe637274e9 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.net.URISyntaxException; import java.util.List; +import java.util.Optional; import java.util.stream.IntStream; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; @@ 
-48,7 +49,7 @@ public void testReadFloatDouble() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("byte_stream_split_float_and_double.parquet").toURI()), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); readAndCompare(reader, getExpectedValues()); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java index e56ac5c16098..ff8bea8c74e5 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java @@ -112,7 +112,7 @@ public void testNanosOutsideDayRange() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("int96_timestamps_nanos_outside_day_range.parquet").toURI()), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); SourcePage page = reader.nextPage(); @@ -166,11 +166,12 @@ private void testVariousTimestamps(TimestampType type) slice.length(), OptionalLong.empty(), null, - false); + false, + 0); // Read and assert ColumnReaderFactory columnReaderFactory = new ColumnReaderFactory(DateTimeZone.UTC, ParquetReaderOptions.defaultOptions()); ColumnReader reader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext()); - PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false); + PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false, Optional.empty(), -1, -1); reader.setPageReader(pageReader, Optional.empty()); reader.prepareNextRead(valueCount); Block block = reader.readPrimitive().getBlock(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java index 102e2b4fc01b..e75be3185a20 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestPageReader.java @@ -384,7 +384,6 @@ private static byte[] compress(CompressionCodec compressionCodec, byte[] bytes, } private static PageReader createPageReader(int valueCount, CompressionCodec compressionCodec, boolean hasDictionary, List slices) - throws IOException { EncodingStats.Builder encodingStats = new EncodingStats.Builder(); if (hasDictionary) { @@ -409,6 +408,7 @@ private static PageReader createPageReader(int valueCount, CompressionCodec comp columnChunkMetaData, new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, INT32, ""), 0, 0), null, + Optional.empty(), Optional.empty()); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java index 2cd1056755e1..4e2db72c471c 100644 --- 
a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java @@ -44,6 +44,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.parquet.ParquetTestUtils.createParquetReader; @@ -76,7 +77,7 @@ public void testColumnReaderMemoryUsage() columnNames, generateInputPages(types, 100, 5)), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThan(1); // Verify file has only non-dictionary encodings as dictionary memory usage is already tested in TestFlatColumnReader#testMemoryUsage parquetMetadata.getBlocks().forEach(block -> { @@ -128,7 +129,7 @@ public void testEmptyRowRangesWithColumnIndex() ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("lineitem_sorted_by_shipdate/data.parquet").toURI()), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); assertThat(parquetMetadata.getBlocks()).hasSize(2); // The predicate and the file are prepared so that page indexes will result in non-overlapping row ranges and eliminate the entire first row group // while the second row group still has to be read @@ -201,20 +202,20 @@ void testReadMetadataWithSplitOffset() ParquetReaderOptions.defaultOptions()); // Read both columns, 1 row group - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); List columnBlocks = parquetMetadata.getBlocks(0, 800); assertThat(columnBlocks.size()).isEqualTo(1); assertThat(columnBlocks.getFirst().columns().size()).isEqualTo(2); assertThat(columnBlocks.getFirst().rowCount()).isEqualTo(100); // Read both columns, half row groups - parquetMetadata = MetadataReader.readFooter(dataSource); + parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); columnBlocks = parquetMetadata.getBlocks(0, 2500); assertThat(columnBlocks.stream().allMatch(block -> block.columns().size() == 2)).isTrue(); assertThat(columnBlocks.stream().mapToLong(BlockMetadata::rowCount).sum()).isEqualTo(300); // Read both columns, all row groups - parquetMetadata = MetadataReader.readFooter(dataSource); + parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); columnBlocks = parquetMetadata.getBlocks(); assertThat(columnBlocks.stream().allMatch(block -> block.columns().size() == 2)).isTrue(); assertThat(columnBlocks.stream().mapToLong(BlockMetadata::rowCount).sum()).isEqualTo(500); @@ -238,7 +239,7 @@ void testMaxFooterReadSize() generateInputPages(types, 10, 50)), ParquetReaderOptions.defaultOptions()); - assertThatThrownBy(() -> MetadataReader.readFooter(dataSource, DataSize.ofBytes(1000))) + assertThatThrownBy(() -> MetadataReader.readFooter(dataSource, DataSize.ofBytes(1000), Optional.empty())) .hasMessageMatching(".* Parquet footer size .* exceeds maximum allowed size .*"); } @@ -248,7 +249,7 @@ private void testReadingOldParquetFiles(File file, List columnNames, Typ ParquetDataSource dataSource = 
new FileParquetDataSource( file, ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), ImmutableList.of(columnType), columnNames)) { SourcePage page = reader.nextPage(); Iterator expected = expectedValues.iterator(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java index ecc23ff12fa1..840d07b4eb28 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.List; +import java.util.Optional; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.parquet.ParquetTestUtils.createParquetReader; @@ -58,7 +59,7 @@ private void testTimeMillsInt32(TimeType timeType) ParquetDataSource dataSource = new FileParquetDataSource( new File(Resources.getResource("time_millis_int32.snappy.parquet").toURI()), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames); SourcePage page = reader.nextPage(); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java index d6b90ff16e20..ec99cf2a4a43 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestFlatColumnReader.java @@ -137,8 +137,9 @@ private static PageReader getSimplePageReaderMock(ParquetEncoding encoding) OptionalLong.empty(), encoding, encoding, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, Optional.empty(), -1, -1); } private static PageReader getNullOnlyPageReaderMock() @@ -154,7 +155,8 @@ private static PageReader getNullOnlyPageReaderMock() OptionalLong.empty(), RLE, RLE, - PLAIN)); - return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false); + PLAIN, + 0)); + return new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, pages.iterator(), false, false, Optional.empty(), -1, -1); } } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java index d004ec939a11..b35d81ee2a7f 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java @@ -77,6 +77,7 @@ import static io.trino.parquet.ParquetTestUtils.createParquetWriter; import static io.trino.parquet.ParquetTestUtils.generateInputPages; import static io.trino.parquet.ParquetTestUtils.writeParquetFile; +import static 
io.trino.parquet.metadata.HiddenColumnChunkMetadata.isHiddenColumn; import static io.trino.parquet.writer.ParquetWriterOptions.DEFAULT_BLOOM_FILTER_FPP; import static io.trino.parquet.writer.ParquetWriters.BLOOM_FILTER_EXPECTED_ENTRIES; import static io.trino.spi.type.BigintType.BIGINT; @@ -130,7 +131,7 @@ public void testWrittenPageSize() columnNames, generateInputPages(types, 100, 1000)), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); assertThat(parquetMetadata.getBlocks()).hasSize(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -144,6 +145,7 @@ public void testWrittenPageSize() chunkMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -179,7 +181,7 @@ public void testWrittenPageValueCount() columnNames, generateInputPages(types, 100, 1000)), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); assertThat(parquetMetadata.getBlocks()).hasSize(1); assertThat(parquetMetadata.getBlocks().get(0).rowCount()).isEqualTo(100 * 1000); @@ -197,6 +199,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columna"}, new PrimitiveType(REQUIRED, INT32, "columna"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -216,6 +219,7 @@ public void testWrittenPageValueCount() columnAMetaData, new ColumnDescriptor(new String[] {"columnb"}, new PrimitiveType(REQUIRED, INT64, "columnb"), 0, 0), null, + Optional.empty(), Optional.empty()); pageReader.readDictionaryPage(); @@ -260,7 +264,7 @@ public void testLargeStringTruncation() ImmutableList.of(new Page(2, blockA, blockB))), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); BlockMetadata blockMetaData = getOnlyElement(parquetMetadata.getBlocks()); ColumnChunkMetadata chunkMetaData = blockMetaData.columns().get(0); @@ -293,7 +297,7 @@ public void testColumnReordering() generateInputPages(types, 100, 100)), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(10); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { // Verify that the columns are stored in the same order as the metadata @@ -350,7 +354,7 @@ public void testDictionaryPageOffset() generateInputPages(types, 100, 100)), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); assertThat(parquetMetadata.getBlocks().size()).isGreaterThanOrEqualTo(1); for (BlockMetadata blockMetaData : parquetMetadata.getBlocks()) { ColumnChunkMetadata chunkMetaData = getOnlyElement(blockMetaData.columns()); @@ -399,7 +403,7 @@ void testRowGroupOffset() generateInputPages(types, 100, 10)), 
ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); List blocks = parquetMetadata.getBlocks(); assertThat(blocks.size()).isGreaterThan(1); @@ -410,7 +414,12 @@ void testRowGroupOffset() RowGroup rowGroup = rowGroups.get(rowGroupIndex); assertThat(rowGroup.isSetFile_offset()).isTrue(); BlockMetadata blockMetadata = blocks.get(rowGroupIndex); - assertThat(blockMetadata.getStartingPos()).isEqualTo(rowGroup.getFile_offset()); + assertThat(blockMetadata.columns().stream() + .filter(column -> !isHiddenColumn(column)) + .findFirst() + .map(ColumnChunkMetadata::getStartingPos) + .orElseThrow()) + .isEqualTo(rowGroup.getFile_offset()); assertThat(blockMetadata.fileRowCountOffset()).isEqualTo(fileRowCountOffset); fileRowCountOffset += rowGroup.getNum_rows(); } @@ -434,7 +443,7 @@ public void testWriteBloomFilters(Type type, List data) generateInputPages(types, 100, data)), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); // Check that bloom filters are right after each other int bloomFilterSize = Integer.highestOneBit(BlockSplitBloomFilter.optimalNumOfBits(BLOOM_FILTER_EXPECTED_ENTRIES, DEFAULT_BLOOM_FILTER_FPP) / 8) << 1; for (BlockMetadata block : parquetMetadata.getBlocks()) { @@ -499,7 +508,7 @@ void testBloomFilterWithDictionaryFallback() .build()), ParquetReaderOptions.defaultOptions()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); BlockMetadata blockMetaData = getOnlyElement(parquetMetadata.getBlocks()); ColumnChunkMetadata chunkMetaData = getOnlyElement(blockMetaData.columns()); assertThat(chunkMetaData.getEncodingStats().hasDictionaryPages()).isTrue(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index ce5543f2bcc2..b1f63d4a45de 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -388,7 +388,7 @@ private Slice writeMergeResult(Slice path, FileDeletion deletion) } TrinoInputFile inputFile = fileSystem.newInputFile(Location.of(path.toStringUtf8())); try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, parquetReaderOptions, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); long rowCount = parquetMetadata.getBlocks().stream().map(BlockMetadata::rowCount).mapToLong(Long::longValue).sum(); RoaringBitmapArray rowsRetained = new RoaringBitmapArray(); rowsRetained.addRange(0, rowCount - 1); @@ -690,6 +690,7 @@ private ConnectorPageSource createParquetPageSource(Location path) new FileFormatDataSourceStats(), ParquetReaderOptions.builder().withBloomFilter(false).build(), Optional.empty(), + Optional.empty(), domainCompactionThreshold, OptionalLong.of(fileSize)); } diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index 9e0f75ec1162..f4bff5ccec4d 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -263,6 +263,7 @@ public ConnectorPageSource createPageSource( fileFormatDataSourceStats, options, Optional.empty(), + Optional.empty(), domainCompactionThreshold, OptionalLong.of(split.getFileSize())); @@ -362,7 +363,7 @@ private static PositionDeleteFilter readDeletes( private Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) { - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize()); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize(), Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java index b192f03727e1..131f603f7c95 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java @@ -179,7 +179,7 @@ public DataFileInfo getDataFileInfo() { Location path = rootTableLocation.appendPath(relativeFilePath); FileMetaData fileMetaData = fileWriter.getFileMetadata(); - ParquetMetadata parquetMetadata = new ParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString())); + ParquetMetadata parquetMetadata = new ParquetMetadata(fileMetaData, new ParquetDataSourceId(path.toString()), Optional.empty()); return new DataFileInfo( relativeFilePath, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java index f52167042a77..c697c8256542 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java @@ -206,6 +206,7 @@ private static ConnectorPageSource createDeltaLakePageSource( fileFormatDataSourceStats, parquetReaderOptions, Optional.empty(), + Optional.empty(), domainCompactionThreshold, OptionalLong.empty()); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index dd62e0eeba44..a8aaeda155ae 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -219,6 +219,7 @@ public CheckpointEntryIterator( stats, parquetReaderOptions, Optional.empty(), + Optional.empty(), domainCompactionThreshold, OptionalLong.of(fileSize)); diff --git 
a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java index 27db0930313b..189f5d28e831 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java @@ -405,7 +405,7 @@ private void testPartitionValuesParsedCheckpoint(ColumnMappingMode columnMapping assertThat(partitionValuesParsedType.getFields().stream().collect(onlyElement()).getName().orElseThrow()).isEqualTo(physicalColumnName); TrinoParquetDataSource dataSource = new TrinoParquetDataSource(new LocalInputFile(checkpoint.toFile()), ParquetReaderOptions.defaultOptions(), new FileFormatDataSourceStats()); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, ImmutableList.of(addEntryType), List.of("add"))) { List actual = new ArrayList<>(); SourcePage page = reader.nextPage(); @@ -483,7 +483,8 @@ private void testOptimizeWithColumnMappingMode(String columnMappingMode) // Verify optimized parquet file contains the expected physical id and name TrinoInputFile inputFile = new LocalInputFile(tableLocation.resolve(addFileEntry.getPath()).toFile()); ParquetMetadata parquetMetadata = MetadataReader.readFooter( - new TrinoParquetDataSource(inputFile, ParquetReaderOptions.defaultOptions(), new FileFormatDataSourceStats())); + new TrinoParquetDataSource(inputFile, ParquetReaderOptions.defaultOptions(), new FileFormatDataSourceStats()), + Optional.empty()); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); PrimitiveType physicalType = getOnlyElement(fileMetaData.getSchema().getColumns().iterator()).getPrimitiveType(); assertThat(physicalType.getName()).isEqualTo(physicalName); diff --git a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index 705f3c9ac6c2..d99075c818a8 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -176,6 +176,11 @@ parquet-column + + org.apache.parquet + parquet-common + + org.apache.parquet parquet-format-structures @@ -330,12 +335,6 @@ runtime - - org.apache.parquet - parquet-common - runtime - - org.jetbrains annotations @@ -423,6 +422,13 @@ test + + io.trino + trino-parquet + test-jar + test + + io.trino trino-parser diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java index 46c4affef94a..30c454f9517b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java @@ -14,14 +14,15 @@ package io.trino.plugin.hive; import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.Singleton; import com.google.inject.multibindings.Multibinder; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.base.metrics.FileFormatDataSourceStats; import io.trino.plugin.hive.avro.AvroFileWriterFactory; import io.trino.plugin.hive.avro.AvroPageSourceFactory; +import io.trino.plugin.hive.crypto.ParquetEncryptionModule; import io.trino.plugin.hive.esri.EsriPageSourceFactory; import io.trino.plugin.hive.fs.CachingDirectoryLister; import 
io.trino.plugin.hive.fs.DirectoryLister; @@ -61,10 +62,10 @@ import static org.weakref.jmx.guice.ExportBinder.newExporter; public class HiveModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + public void setup(Binder binder) { configBinder(binder).bindConfig(HiveConfig.class); configBinder(binder).bindConfig(HiveMetastoreConfig.class); @@ -136,6 +137,7 @@ public void configure(Binder binder) fileWriterFactoryBinder.addBinding().to(ParquetFileWriterFactory.class).in(Scopes.SINGLETON); binder.install(new HiveExecutorModule()); + install(new ParquetEncryptionModule()); } @Provides diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/EnvironmentDecryptionKeyRetriever.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/EnvironmentDecryptionKeyRetriever.java new file mode 100644 index 000000000000..ce5e0e0d50a3 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/EnvironmentDecryptionKeyRetriever.java @@ -0,0 +1,135 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.crypto; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; + +import static org.apache.parquet.Preconditions.checkArgument; + +/** + * Reads keys from two environment variables. + *
+ * <pre>
+ *   pme.environment-key-retriever.footer-keys  =  single-key | id1:key1,id2:key2,...
+ *   pme.environment-key-retriever.column-keys  =  single-key | id1:key1,id2:key2,...
+ * </pre>
+ * <ul>
+ *   <li>If the value contains {@code :}, it is treated as a map of comma-separated {@code id:key} entries.
+ *       The {@code id} must match the {@code keyMetadata} supplied by Parquet.</li>
+ *   <li>Otherwise it is a single default key, applied regardless of {@code keyMetadata}.</li>
+ *   <li>Keys must be Base64-encoded.</li>
+ * </ul>
+ */ +public final class EnvironmentDecryptionKeyRetriever + implements DecryptionKeyRetriever +{ + private static final String FOOTER_VARIABLE_NAME = "pme.environment-key-retriever.footer-keys"; + private static final String COLUMN_VARIABLE_NAME = "pme.environment-key-retriever.column-keys"; + + private final KeySource footerKeys; + private final KeySource columnKeys; + + @Inject + public EnvironmentDecryptionKeyRetriever() + { + this(parseEnvironmentVariable(FOOTER_VARIABLE_NAME), parseEnvironmentVariable(COLUMN_VARIABLE_NAME)); + } + + @VisibleForTesting + EnvironmentDecryptionKeyRetriever(String footerValue, String columnValue) + { + this(parseValue(footerValue, FOOTER_VARIABLE_NAME), parseValue(columnValue, COLUMN_VARIABLE_NAME)); + } + + private EnvironmentDecryptionKeyRetriever(KeySource footerKeys, KeySource columnKeys) + { + this.footerKeys = footerKeys; + this.columnKeys = columnKeys; + } + + @Override + public Optional getColumnKey(ColumnPath columnPath, Optional keyMetadata) + { + return columnKeys.resolve(keyMetadata); + } + + @Override + public Optional getFooterKey(Optional keyMetadata) + { + return footerKeys.resolve(keyMetadata); + } + + private static KeySource parseEnvironmentVariable(String variable) + { + return parseValue(System.getenv(variable), variable); + } + + private static KeySource parseValue(String value, String variable) + { + if (value == null || value.isBlank()) { + return KeySource.empty(); + } + if (value.contains(":")) { + // map mode + ImmutableMap.Builder map = ImmutableMap.builder(); + for (String entry : value.split("\\s*,\\s*")) { + checkArgument(!entry.isBlank(), "Empty entry in %s", variable); + if (entry.isBlank()) { + continue; + } + String[] parts = entry.split(":", 2); + checkArgument(parts.length == 2, "Malformed entry in %s: %s", variable, entry); + map.put(ByteBuffer.wrap(parts[0].getBytes(StandardCharsets.UTF_8)), decodeKey(parts[1])); + } + return new KeySource(Optional.empty(), map.buildOrThrow()); + } + // single key mode + return new KeySource(Optional.of(decodeKey(value)), ImmutableMap.of()); + } + + private static byte[] decodeKey(String token) + { + return Base64.getDecoder().decode(token); + } + + /** + * container for either a single default key or a map keyed by key‑metadata id + */ + private record KeySource(Optional singleKey, Map keyedKeys) + { + static KeySource empty() + { + return new KeySource(Optional.empty(), Collections.emptyMap()); + } + + Optional resolve(Optional keyMetadata) + { + if (singleKey.isPresent()) { + // no map → keyMetadata irrelevant + return singleKey; + } + return keyMetadata.map(bytes -> keyedKeys.get(ByteBuffer.wrap(bytes))); + } + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/ParquetEncryptionConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/ParquetEncryptionConfig.java new file mode 100644 index 000000000000..5126d8daf317 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/ParquetEncryptionConfig.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.crypto; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +import java.util.Optional; + +public class ParquetEncryptionConfig +{ + private boolean environmentKeyRetrieverEnabled; + private Optional aadPrefix = Optional.empty(); + private boolean checkFooterIntegrity = true; + + @Config("pme.environment-key-retriever.enabled") + @ConfigDescription("Enable the key retriever that retrieves keys from the environment variable") + public ParquetEncryptionConfig setEnvironmentKeyRetrieverEnabled(boolean enabled) + { + this.environmentKeyRetrieverEnabled = enabled; + return this; + } + + public boolean isEnvironmentKeyRetrieverEnabled() + { + return environmentKeyRetrieverEnabled; + } + + @Config("pme.aad-prefix") + @ConfigDescription("AAD prefix used to decode Parquet files") + public ParquetEncryptionConfig setAadPrefix(String prefix) + { + this.aadPrefix = Optional.ofNullable(prefix); + return this; + } + + public Optional getAadPrefix() + { + return aadPrefix; + } + + @Config("pme.check-footer-integrity") + @ConfigDescription("Validate signature for plaintext footer files") + public ParquetEncryptionConfig setCheckFooterIntegrity(boolean checkFooterIntegrity) + { + this.checkFooterIntegrity = checkFooterIntegrity; + return this; + } + + public boolean isCheckFooterIntegrity() + { + return checkFooterIntegrity; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/ParquetEncryptionModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/ParquetEncryptionModule.java new file mode 100644 index 000000000000..c16bad2c38c5 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/crypto/ParquetEncryptionModule.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive.crypto; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import com.google.inject.Scopes; +import com.google.inject.Singleton; +import com.google.inject.multibindings.Multibinder; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.hadoop.metadata.ColumnPath; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +public class ParquetEncryptionModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + ParquetEncryptionConfig config = buildConfigObject(ParquetEncryptionConfig.class); + Multibinder retrieverBinder = + Multibinder.newSetBinder(binder, DecryptionKeyRetriever.class); + if (config.isEnvironmentKeyRetrieverEnabled()) { + retrieverBinder.addBinding().to(EnvironmentDecryptionKeyRetriever.class).in(Scopes.SINGLETON); + } + } + + @Provides + @Singleton + public Optional fileDecryptionProperties( + ParquetEncryptionConfig config, + Set retrievers) + { + if (retrievers.isEmpty()) { + return Optional.empty(); + } + + DecryptionKeyRetriever aggregate = new CompositeDecryptionKeyRetriever(List.copyOf(retrievers)); + + FileDecryptionProperties.Builder builder = FileDecryptionProperties.builder() + .withKeyRetriever(aggregate) + .withCheckFooterIntegrity(config.isCheckFooterIntegrity()); + + config.getAadPrefix() + .map(string -> string.getBytes(StandardCharsets.UTF_8)) + .ifPresent(builder::withAadPrefix); + + return Optional.of(builder.build()); + } + + private static class CompositeDecryptionKeyRetriever + implements DecryptionKeyRetriever + { + private final List delegates; + + CompositeDecryptionKeyRetriever(List delegates) + { + this.delegates = List.copyOf(delegates); + } + + @Override + public Optional getColumnKey(ColumnPath path, Optional meta) + { + return delegates.stream() + .map(delegate -> delegate.getColumnKey(path, meta)) + .filter(Optional::isPresent) + .findFirst() + .orElse(Optional.empty()); + } + + @Override + public Optional getFooterKey(Optional meta) + { + return delegates.stream() + .map(delegate -> delegate.getFooterKey(meta)) + .filter(Optional::isPresent) + .findFirst() + .orElse(Optional.empty()); + } + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index a8df03061c89..33c144f2e12b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -33,6 +33,7 @@ import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.crypto.FileDecryptionProperties; import io.trino.parquet.metadata.FileMetadata; import io.trino.parquet.metadata.ParquetMetadata; import io.trino.parquet.predicate.TupleDomainParquetPredicate; @@ -134,6 +135,7 @@ public class ParquetPageSourceFactory private final TrinoFileSystemFactory fileSystemFactory; private final FileFormatDataSourceStats stats; + private final Optional fileDecryptionProperties; private final ParquetReaderOptions options; private final DateTimeZone timeZone; private final int domainCompactionThreshold; @@ 
-142,11 +144,13 @@ public class ParquetPageSourceFactory public ParquetPageSourceFactory( TrinoFileSystemFactory fileSystemFactory, FileFormatDataSourceStats stats, + Optional fileDecryptionProperties, ParquetReaderConfig config, HiveConfig hiveConfig) { this.fileSystemFactory = requireNonNull(fileSystemFactory, "fileSystemFactory is null"); this.stats = requireNonNull(stats, "stats is null"); + this.fileDecryptionProperties = requireNonNull(fileDecryptionProperties, "fileDecryptionProperties is null"); options = config.toParquetReaderOptions(); timeZone = hiveConfig.getParquetDateTimeZone(); domainCompactionThreshold = hiveConfig.getDomainCompactionThreshold(); @@ -201,6 +205,7 @@ public Optional createPageSource( .withVectorizedDecodingEnabled(isParquetVectorizedDecodingEnabled(session)) .build(), Optional.empty(), + fileDecryptionProperties, domainCompactionThreshold, OptionalLong.of(estimatedFileSize))); } @@ -219,6 +224,7 @@ public static ConnectorPageSource createPageSource( FileFormatDataSourceStats stats, ParquetReaderOptions options, Optional parquetWriteValidation, + Optional fileDecryptionProperties, int domainCompactionThreshold, OptionalLong estimatedFileSize) { @@ -230,7 +236,7 @@ public static ConnectorPageSource createPageSource( AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); dataSource = createDataSource(inputFile, estimatedFileSize, options, memoryContext, stats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.of(options.getMaxFooterReadSize()), parquetWriteValidation); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.of(options.getMaxFooterReadSize()), parquetWriteValidation, fileDecryptionProperties); FileMetadata fileMetaData = parquetMetadata.getFileMetaData(); fileSchema = fileMetaData.getSchema(); @@ -285,7 +291,8 @@ public static ConnectorPageSource createPageSource( // We avoid using disjuncts of parquetPredicate for page pruning in ParquetReader as currently column indexes // are not present in the Parquet files which are read with disjunct predicates. parquetPredicates.size() == 1 ? 
Optional.of(parquetPredicates.getFirst()) : Optional.empty(), - parquetWriteValidation); + parquetWriteValidation, + parquetMetadata.getDecryptionContext()); return createParquetPageSource(columns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); } catch (Exception e) { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java index 2eb36d67a8e5..f319a9932f27 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java @@ -23,6 +23,7 @@ import io.trino.metastore.Database; import io.trino.metastore.HiveMetastore; import io.trino.metastore.HiveMetastoreFactory; +import io.trino.parquet.crypto.DecryptionKeyRetriever; import io.trino.plugin.tpcds.TpcdsPlugin; import io.trino.plugin.tpch.ColumnNaming; import io.trino.plugin.tpch.DecimalTypeMapping; @@ -105,6 +106,7 @@ public static class Builder> private boolean createTpchSchemas = true; private ColumnNaming tpchColumnNaming = SIMPLIFIED; private DecimalTypeMapping tpchDecimalTypeMapping = DOUBLE; + private Optional decryptionKeyRetriever = Optional.empty(); protected Builder() { @@ -196,6 +198,13 @@ public SELF setTpchDecimalTypeMapping(DecimalTypeMapping tpchDecimalTypeMapping) return self(); } + @CanIgnoreReturnValue + public SELF setDecryptionKeyRetriever(DecryptionKeyRetriever decryptionKeyRetriever) + { + this.decryptionKeyRetriever = Optional.of(requireNonNull(decryptionKeyRetriever, "decryptionKeyRetriever is null")); + return self(); + } + @Override public DistributedQueryRunner build() throws Exception @@ -227,7 +236,7 @@ public DistributedQueryRunner build() hiveProperties.put("fs.hadoop.enabled", "true"); } - queryRunner.installPlugin(new TestingHivePlugin(dataDir, metastore)); + queryRunner.installPlugin(new TestingHivePlugin(dataDir, metastore, decryptionKeyRetriever)); Map hiveProperties = new HashMap<>(); if (!skipTimezoneSetup) { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveTestUtils.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveTestUtils.java index 85a5a62edde9..0b551e13a6a0 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveTestUtils.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveTestUtils.java @@ -178,7 +178,7 @@ public static Set getDefaultHivePageSourceFactories(Trino .add(new AvroPageSourceFactory(fileSystemFactory)) .add(new RcFilePageSourceFactory(fileSystemFactory, hiveConfig)) .add(new OrcPageSourceFactory(new OrcReaderConfig(), fileSystemFactory, stats, hiveConfig)) - .add(new ParquetPageSourceFactory(fileSystemFactory, stats, new ParquetReaderConfig(), hiveConfig)) + .add(new ParquetPageSourceFactory(fileSystemFactory, stats, Optional.empty(), new ParquetReaderConfig(), hiveConfig)) .add(new ProtobufSequenceFilePageSourceFactory(fileSystemFactory, hiveConfig)) .build(); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java index 255a2c728a6d..8f79bf1acfef 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java @@ -20,6 +20,7 @@ import io.airlift.units.Duration; import org.junit.jupiter.api.Test; +import java.io.IOException; import java.nio.file.Path; import java.util.Map; import 
java.util.TimeZone; @@ -128,6 +129,7 @@ public void testDefaults() @Test public void testExplicitPropertyMappings() + throws IOException { Map properties = ImmutableMap.builder() .put("hive.single-statement-writes", "true") diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java index cc7e10059471..0960d6b29fcd 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java @@ -554,7 +554,7 @@ public void testParquetPageSource(int rowCount, long fileSizePadding) .withSession(PARQUET_SESSION) .withRowsCount(rowCount) .withFileSizePadding(fileSizePadding) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); } @Test(dataProvider = "validRowAndFileSizePadding") @@ -568,7 +568,7 @@ public void testParquetPageSourceGzip(int rowCount, long fileSizePadding) .withCompressionCodec(HiveCompressionCodec.GZIP) .withFileSizePadding(fileSizePadding) .withRowsCount(rowCount) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); } @Test(dataProvider = "rowCount") @@ -583,7 +583,7 @@ public void testParquetWriter(int rowCount) .withColumns(testColumns) .withRowsCount(rowCount) .withFileWriterFactory(fileSystemFactory -> new ParquetFileWriterFactory(fileSystemFactory, new NodeVersion("test-version"), TESTING_TYPE_MANAGER, new HiveConfig(), STATS)) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); } @Test(dataProvider = "rowCount") @@ -601,7 +601,7 @@ public void testParquetPageSourceSchemaEvolution(int rowCount) .withReadColumns(readColumns) .withSession(PARQUET_SESSION) .withRowsCount(rowCount) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); // test the name-based access readColumns = writeColumns.reversed(); @@ -609,7 +609,7 @@ public void testParquetPageSourceSchemaEvolution(int rowCount) .withWriteColumns(writeColumns) .withReadColumns(readColumns) .withSession(PARQUET_SESSION_USE_NAME) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); } @Test(dataProvider = "rowCount") @@ -627,7 +627,7 @@ public void testParquetCaseSensitivity(int rowCount) 
.withSession(getHiveSession(createParquetHiveConfig(true), new ParquetWriterConfig().setValidationPercentage(0))) .withRowsCount(rowCount) .withFileWriterFactory(fileSystemFactory -> new ParquetFileWriterFactory(fileSystemFactory, new NodeVersion("test-version"), TESTING_TYPE_MANAGER, new HiveConfig(), STATS)) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); } private static List getTestColumnsSupportedByParquet() @@ -670,7 +670,7 @@ public void testTruncateVarcharColumn() .withWriteColumns(ImmutableList.of(writeColumn)) .withReadColumns(ImmutableList.of(readColumn)) .withSession(PARQUET_SESSION) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); assertThatFileFormat(AVRO) .withWriteColumns(ImmutableList.of(writeColumn)) @@ -736,14 +736,14 @@ public void testParquetProjectedColumns(int rowCount) .withReadColumns(readColumns) .withRowsCount(rowCount) .withSession(PARQUET_SESSION) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); assertThatFileFormat(PARQUET) .withWriteColumns(writeColumns) .withReadColumns(readColumns) .withRowsCount(rowCount) .withSession(PARQUET_SESSION_USE_NAME) - .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig())); + .isReadableByPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig())); } @Test(dataProvider = "rowCount") @@ -944,7 +944,7 @@ public void testFailForLongVarcharPartitionColumn() assertThatFileFormat(PARQUET) .withColumns(columns) .withSession(PARQUET_SESSION) - .isFailingForPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, new ParquetReaderConfig(), new HiveConfig()), expectedErrorCode, expectedMessage); + .isFailingForPageSource(fileSystemFactory -> new ParquetPageSourceFactory(fileSystemFactory, STATS, Optional.empty(), new ParquetReaderConfig(), new HiveConfig()), expectedErrorCode, expectedMessage); } private static void testPageSourceFactory( diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHiveConnectorFactory.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHiveConnectorFactory.java index c0e03b111e10..e26a56ed33e7 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHiveConnectorFactory.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHiveConnectorFactory.java @@ -15,9 +15,11 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Module; +import com.google.inject.multibindings.Multibinder; import io.trino.filesystem.TrinoFileSystemFactory; import io.trino.filesystem.local.LocalFileSystemFactory; import io.trino.metastore.HiveMetastore; +import 
io.trino.parquet.crypto.DecryptionKeyRetriever; import io.trino.plugin.hive.metastore.file.FileHiveMetastoreConfig; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorContext; @@ -40,11 +42,11 @@ public class TestingHiveConnectorFactory public TestingHiveConnectorFactory(Path localFileSystemRootPath) { - this(localFileSystemRootPath, Optional.empty()); + this(localFileSystemRootPath, Optional.empty(), Optional.empty()); } @Deprecated - public TestingHiveConnectorFactory(Path localFileSystemRootPath, Optional metastore) + public TestingHiveConnectorFactory(Path localFileSystemRootPath, Optional metastore, Optional decryptionKeyRetriever) { this.metastore = requireNonNull(metastore, "metastore is null"); @@ -53,6 +55,12 @@ public TestingHiveConnectorFactory(Path localFileSystemRootPath, Optional config.setCatalogDirectory("local:///")); + + decryptionKeyRetriever.ifPresent(retriever -> { + Multibinder retrieverBinder = + Multibinder.newSetBinder(binder, DecryptionKeyRetriever.class); + retrieverBinder.addBinding().toInstance(retriever); + }); }; } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHivePlugin.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHivePlugin.java index 41f492a99405..b472f9355893 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHivePlugin.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingHivePlugin.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.metastore.HiveMetastore; +import io.trino.parquet.crypto.DecryptionKeyRetriever; import io.trino.spi.Plugin; import io.trino.spi.connector.ConnectorFactory; @@ -28,28 +29,30 @@ public class TestingHivePlugin { private final Path localFileSystemRootPath; private final Optional metastore; + private final Optional decryptionKeyRetriever; public TestingHivePlugin(Path localFileSystemRootPath) { - this(localFileSystemRootPath, Optional.empty()); + this(localFileSystemRootPath, Optional.empty(), Optional.empty()); } @Deprecated public TestingHivePlugin(Path localFileSystemRootPath, HiveMetastore metastore) { - this(localFileSystemRootPath, Optional.of(metastore)); + this(localFileSystemRootPath, Optional.of(metastore), Optional.empty()); } @Deprecated - public TestingHivePlugin(Path localFileSystemRootPath, Optional metastore) + public TestingHivePlugin(Path localFileSystemRootPath, Optional metastore, Optional decryptionKeyRetriever) { this.localFileSystemRootPath = requireNonNull(localFileSystemRootPath, "localFileSystemRootPath is null"); this.metastore = requireNonNull(metastore, "metastore is null"); + this.decryptionKeyRetriever = requireNonNull(decryptionKeyRetriever, "decryptionKeyRetriever is null"); } @Override public Iterable getConnectorFactories() { - return ImmutableList.of(new TestingHiveConnectorFactory(localFileSystemRootPath, metastore)); + return ImmutableList.of(new TestingHiveConnectorFactory(localFileSystemRootPath, metastore, decryptionKeyRetriever)); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestEnvironmentDecryptionKeyRetriever.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestEnvironmentDecryptionKeyRetriever.java new file mode 100644 index 000000000000..66537a8e478c --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestEnvironmentDecryptionKeyRetriever.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.crypto; + +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.testng.annotations.Test; + +import java.util.Base64; +import java.util.Optional; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; + +public final class TestEnvironmentDecryptionKeyRetriever +{ + private static final ColumnPath AGE = ColumnPath.fromDotString("age"); + + @Test + public void defaultEmpty() + { + EnvironmentDecryptionKeyRetriever retriever = new EnvironmentDecryptionKeyRetriever(null, null); + + assertThat(retriever.getFooterKey(Optional.empty())).isEmpty(); + assertThat(retriever.getColumnKey(AGE, Optional.empty())).isEmpty(); + } + + @Test + public void singleKeyMode() + { + EnvironmentDecryptionKeyRetriever retriever = new EnvironmentDecryptionKeyRetriever(b64("foot"), b64("colKey")); + + assertThat(retriever.getFooterKey(Optional.of("ignored".getBytes(UTF_8)))).contains("foot".getBytes(UTF_8)); + assertThat(retriever.getColumnKey(AGE, Optional.empty())).contains("colKey".getBytes(UTF_8)); + } + + @Test + public void mapModeUsesMetadata() + { + // footer: id1→k1 , id2→k2 + String footerValue = String.join(",", + "id1:" + b64("k1"), "id2:" + b64("k2")); + // column: meta→ageKey + String columnValue = "meta:" + b64("ageKey"); + + EnvironmentDecryptionKeyRetriever retriever = new EnvironmentDecryptionKeyRetriever(footerValue, columnValue); + + assertThat(retriever.getFooterKey(Optional.of("id2".getBytes(UTF_8)))).contains("k2".getBytes(UTF_8)); + assertThat(retriever.getColumnKey(AGE, Optional.of("meta".getBytes(UTF_8)))).contains("ageKey".getBytes(UTF_8)); + + // unknown metadata → empty + assertThat(retriever.getFooterKey(Optional.of("zzz".getBytes(UTF_8)))).isEmpty(); + assertThat(retriever.getColumnKey(AGE, Optional.empty())).isEmpty(); + } + + private static String b64(String string) + { + return Base64.getEncoder().encodeToString(string.getBytes(UTF_8)); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestHiveParquetEncryption.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestHiveParquetEncryption.java new file mode 100644 index 000000000000..a90e652285eb --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestHiveParquetEncryption.java @@ -0,0 +1,334 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive.crypto; + +import com.google.common.collect.ImmutableMap; +import io.trino.filesystem.Location; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.crypto.DecryptionKeyRetriever; +import io.trino.parquet.crypto.FileDecryptionProperties; +import io.trino.parquet.metadata.ColumnChunkMetadata; +import io.trino.parquet.metadata.ParquetMetadata; +import io.trino.parquet.reader.FileParquetDataSource; +import io.trino.parquet.reader.MetadataReader; +import io.trino.plugin.hive.HiveQueryRunner; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import org.apache.hadoop.conf.Configuration; +import org.apache.parquet.crypto.ColumnEncryptionProperties; +import org.apache.parquet.crypto.FileEncryptionProperties; +import org.apache.parquet.crypto.ParquetCipher; +import org.apache.parquet.example.data.Group; +import org.apache.parquet.example.data.simple.SimpleGroupFactory; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.example.ExampleParquetWriter; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.schema.MessageTypeParser; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; + +import java.nio.file.Files; +import java.util.Map; +import java.util.Optional; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static org.apache.parquet.hadoop.ParquetFileWriter.Mode.OVERWRITE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +/** + * End‑to‑end PME flow: Parquet writer → Hive connector → Trino query. 
+ */ +// ExampleParquetWriter is not thread-safe +@TestInstance(PER_CLASS) +@Execution(ExecutionMode.SAME_THREAD) +public class TestHiveParquetEncryption + extends AbstractTestQueryFramework +{ + private static final byte[] FOOTER_KEY = "footKeyIs16Byte?".getBytes(UTF_8); + + // Keys per column (different on purpose) + private static final byte[] COLUMN_KEY_AGE = "colKeyIs16ByteA?".getBytes(UTF_8); + private static final byte[] COLUMN_KEY_ID = "colKeyIs16ByteB?".getBytes(UTF_8); + + private static final ColumnPath AGE_PATH = ColumnPath.fromDotString("age"); + private static final ColumnPath ID_PATH = ColumnPath.fromDotString("id"); + + /** + * kept so we can reference the warehouse path later + */ + private java.nio.file.Path warehouseDir; + + @Override + protected DistributedQueryRunner createQueryRunner() + throws Exception + { + warehouseDir = Files.createTempDirectory("pme_hive"); + + Map properties = ImmutableMap.builder() + .put("hive.metastore", "file") + .put("hive.metastore.catalog.dir", warehouseDir.toUri().toString()) + .put("pme.environment-key-retriever.enabled", "false") + .buildOrThrow(); + + // Bind retriever that knows ONLY the footer key + age column key + return HiveQueryRunner.builder() + .setHiveProperties(properties) + .setDecryptionKeyRetriever(new TestingParquetEncryptionModule(FOOTER_KEY, Optional.of(COLUMN_KEY_AGE), Optional.empty())) + .build(); + } + + @Test + public void testEncryptedParquetRead() + throws Exception + { + // 1) write encrypted file inside warehouse + java.nio.file.Path dataDir = Files.createDirectories(warehouseDir.resolve("pme_data")); + java.nio.file.Path parquetFile = dataDir.resolve("data.parquet"); + writeEncryptedFile(parquetFile); + + // 2) create external table + String location = Location.of(String.valueOf(dataDir.toUri())).toString(); + assertUpdate(""" + CREATE TABLE enc_age(age INT) + WITH (external_location = '%s', format = 'PARQUET') + """.formatted(location)); + + // 3) verify results + MaterializedResult result = computeActual("SELECT COUNT(*), MIN(age), MAX(age) FROM enc_age"); + assertThat(result.getMaterializedRows().get(0).getField(0)).isEqualTo(100L); // count + assertThat(result.getMaterializedRows().get(0).getField(1)).isEqualTo(0); // min + assertThat(result.getMaterializedRows().get(0).getField(2)).isEqualTo(99); // max + } + + @Test + public void testTwoColumnFileOnlyAgeKeyProvided() + throws Exception + { + // 1) write a two-column file (both columns encrypted with different keys) + java.nio.file.Path dataDir = Files.createDirectories(warehouseDir.resolve("pme_two_cols")); + java.nio.file.Path parquetFile = dataDir.resolve("data.parquet"); + writeTwoColumnEncryptedFile(parquetFile); + + // 2) create external table with both columns + String location = Location.of(String.valueOf(dataDir.toUri())).toString(); + assertUpdate(""" + CREATE TABLE enc_two(id INT, age INT) + WITH (external_location = '%s', format = 'PARQUET') + """.formatted(location)); + + // 3) Selecting ONLY the accessible column (age) must succeed + MaterializedResult ok = computeActual("SELECT COUNT(*), MIN(age), MAX(age) FROM enc_two"); + assertThat(ok.getMaterializedRows().get(0).getField(0)).isEqualTo(100L); + assertThat(ok.getMaterializedRows().get(0).getField(1)).isEqualTo(0); + assertThat(ok.getMaterializedRows().get(0).getField(2)).isEqualTo(99); + + // 4) Selecting the inaccessible column (id) must fail (no column key) + // We match on a broad error text to avoid coupling to a specific message. 
+ assertQueryFails("SELECT MIN(id) FROM enc_two", "(?s).*User does not have access to column.*"); + } + + @Test + public void testEncryptedDictionaryPruningTwoColumns() + throws Exception + { + // 1) write an encrypted, dictionary-encoded two-column file + // Values are 0..10 except the “missing” ones; RG min/max is [0,10] so min/max alone can’t prune. + java.nio.file.Path dataDir = Files.createDirectories(warehouseDir.resolve("pme_dict2_enc")); + java.nio.file.Path parquetFile = dataDir.resolve("data.parquet"); + int missingAge = 7; + int missingId = 3; + writeTwoColumnEncryptedDictionaryFile(parquetFile, missingAge, missingId); + + // 2) create external table + String location = Location.of(String.valueOf(dataDir.toUri())).toString(); + assertUpdate(""" + CREATE TABLE enc_dict2(id INT, age INT) + WITH (external_location = '%s', format = 'PARQUET') + """.formatted(location)); + + // 3) Predicate on the accessible column (age) – dictionary (encrypted) gets read & prunes to zero + assertThat(computeActual("SELECT count(*) FROM enc_dict2 WHERE age = " + missingAge).getOnlyValue()) + .isEqualTo(0L); + + // sanity: present value on age returns rows + assertThat(((Number) computeActual("SELECT count(*) FROM enc_dict2 WHERE age = 5").getOnlyValue()).longValue()) + .isGreaterThan(0L); + + // 4) Predicate on the inaccessible column (id) – should fail (no column key) + assertQueryFails("SELECT count(*) FROM enc_dict2 WHERE id = " + missingId, "(?s).*access.*column.*id.*"); + + assertQuerySucceeds("DROP TABLE enc_dict2"); + } + + // ───────────────────────── writer ───────────────────────── + private static void writeEncryptedFile(java.nio.file.Path path) + throws Exception + { + var schema = MessageTypeParser.parseMessageType( + "message doc { required int32 age; }"); + + // This test purposely reuses one demo key. + // With Parquet Modular Encryption (AES-GCM/CTR), reusing a key for lots of data + // weakens security. 
+ ColumnEncryptionProperties columnProperties = ColumnEncryptionProperties.builder(AGE_PATH) + .withKey(COLUMN_KEY_AGE) + .build(); + + FileEncryptionProperties encodingProperties = FileEncryptionProperties.builder(FOOTER_KEY) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(Map.of(AGE_PATH, columnProperties)) + .build(); + + try (ParquetWriter writer = ExampleParquetWriter.builder(new org.apache.hadoop.fs.Path(path.toString())) + .withType(schema) + .withConf(new Configuration()) + .withEncryption(encodingProperties) + .withWriteMode(OVERWRITE) + .build()) { + SimpleGroupFactory factory = new SimpleGroupFactory(schema); + for (int i = 0; i < 100; i++) { + writer.write(factory.newGroup().append("age", i)); + } + } + } + + private static void writeTwoColumnEncryptedFile(java.nio.file.Path path) + throws Exception + { + var schema = MessageTypeParser.parseMessageType("message doc { required int32 age; required int32 id; }"); + + ColumnEncryptionProperties idProperties = ColumnEncryptionProperties.builder(ID_PATH).withKey(COLUMN_KEY_ID).build(); + ColumnEncryptionProperties ageProperties = ColumnEncryptionProperties.builder(AGE_PATH).withKey(COLUMN_KEY_AGE).build(); + + FileEncryptionProperties encodingProperties = FileEncryptionProperties.builder(FOOTER_KEY) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(Map.of(AGE_PATH, ageProperties, ID_PATH, idProperties)) + .build(); + + try (ParquetWriter writer = ExampleParquetWriter.builder(new org.apache.hadoop.fs.Path(path.toString())) + .withType(schema) + .withConf(new Configuration()) + .withEncryption(encodingProperties) + .withWriteMode(OVERWRITE) + .build()) { + SimpleGroupFactory factory = new SimpleGroupFactory(schema); + for (int i = 0; i < 100; i++) { + writer.write(factory.newGroup().append("id", 100 - i).append("age", i)); + } + } + } + + /** + * Two-column writer: + * - both columns encrypted (different keys), + * - tiny page size to encourage dictionary encoding, + * - each column skips one value in 0..10 so dictionary-based pruning can eliminate the row-group. + * Reader in this test only has the AGE key (ID key is absent). 
+ */ + private static void writeTwoColumnEncryptedDictionaryFile(java.nio.file.Path path, int missingAge, int missingId) + throws Exception + { + var schema = MessageTypeParser.parseMessageType("message doc { required int32 age; required int32 id; }"); + + ColumnEncryptionProperties idProperties = ColumnEncryptionProperties.builder(ID_PATH).withKey(COLUMN_KEY_ID).build(); + ColumnEncryptionProperties ageProperties = ColumnEncryptionProperties.builder(AGE_PATH).withKey(COLUMN_KEY_AGE).build(); + + FileEncryptionProperties encodingProperties = FileEncryptionProperties.builder(FOOTER_KEY) + .withAlgorithm(ParquetCipher.AES_GCM_CTR_V1) + .withEncryptedColumns(Map.of(AGE_PATH, ageProperties, ID_PATH, idProperties)) + .build(); + + try (ParquetWriter writer = ExampleParquetWriter.builder(new org.apache.hadoop.fs.Path(path.toString())) + .withType(schema) + .withConf(new Configuration()) + .withEncryption(encodingProperties) + .withWriteMode(OVERWRITE) + .withPageSize(1024) // small pages -> dictionary likely + .build()) { + SimpleGroupFactory factory = new SimpleGroupFactory(schema); + for (int i = 0; i < 5000; i++) { + int id = i % 11; + if (id == missingId) { + id = (id + 1) % 11; // skip one value for id + } + int age = i % 11; + if (age == missingAge) { + age = (age + 1) % 11; // skip one value for age + } + writer.write(factory.newGroup().append("id", id).append("age", age)); + } + } + + // Verify both columns actually have dictionary pages + try (ParquetDataSource source = new FileParquetDataSource(path.toFile(), ParquetReaderOptions.defaultOptions())) { + FileDecryptionProperties dec = FileDecryptionProperties.builder() + .withKeyRetriever(new TestHiveParquetEncryption.TestingParquetEncryptionModule( + FOOTER_KEY, Optional.of(COLUMN_KEY_AGE), Optional.of(COLUMN_KEY_ID))) + .build(); + ParquetMetadata metadata = MetadataReader.readFooter(source, Optional.empty(), Optional.empty(), Optional.of(dec)); + + ColumnChunkMetadata ageChunk = metadata.getBlocks().getFirst().columns().stream() + .filter(column -> column.getPath().equals(AGE_PATH)) + .findFirst().orElseThrow(); + ColumnChunkMetadata idChunk = metadata.getBlocks().getFirst().columns().stream() + .filter(column -> column.getPath().equals(ID_PATH)) + .findFirst().orElseThrow(); + + assertThat(ageChunk.getDictionaryPageOffset()).isGreaterThan(0); + assertThat(idChunk.getDictionaryPageOffset()).isGreaterThan(0); + assertThat(ageChunk.getEncodings()).anyMatch(org.apache.parquet.column.Encoding::usesDictionary); + assertThat(idChunk.getEncodings()).anyMatch(org.apache.parquet.column.Encoding::usesDictionary); + } + } + + public static class TestingParquetEncryptionModule + implements DecryptionKeyRetriever + { + private final byte[] footerKey; + private final Optional ageKey; + private final Optional idKey; + + public TestingParquetEncryptionModule(byte[] footerKey, Optional ageKey, Optional idKey) + { + this.footerKey = requireNonNull(footerKey, "footerKey is null"); + this.ageKey = requireNonNull(ageKey, "ageKey is null"); + this.idKey = requireNonNull(idKey, "idKey is null"); + } + + @Override + public Optional getColumnKey(ColumnPath columnPath, Optional keyMetadata) + { + String path = columnPath.toDotString(); + if ("age".equals(path)) { + return ageKey; + } + if ("id".equals(path)) { + return idKey; + } + return Optional.empty(); + } + + @Override + public Optional getFooterKey(Optional keyMetadata) + { + return Optional.of(footerKey); + } + } +} diff --git 
a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestParquetEncryptionConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestParquetEncryptionConfig.java new file mode 100644 index 000000000000..641863e261d7 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/crypto/TestParquetEncryptionConfig.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.crypto; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestParquetEncryptionConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(ParquetEncryptionConfig.class) + .setEnvironmentKeyRetrieverEnabled(false) + .setAadPrefix(null) + .setCheckFooterIntegrity(true)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder() + .put("pme.environment-key-retriever.enabled", "true") + .put("pme.aad-prefix", "tenant‑42") + .put("pme.check-footer-integrity", "false") + .buildOrThrow(); + + ParquetEncryptionConfig expected = new ParquetEncryptionConfig() + .setEnvironmentKeyRetrieverEnabled(true) + .setAadPrefix("tenant‑42") + .setCheckFooterIntegrity(false); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java index d93a46589c29..e4ad205c98ac 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java @@ -80,6 +80,7 @@ private static ConnectorPageSource createPageSource(ConnectorSession session, Fi HivePageSourceFactory hivePageSourceFactory = new ParquetPageSourceFactory( fileSystemFactory, new FileFormatDataSourceStats(), + Optional.empty(), new ParquetReaderConfig(), hiveConfig); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java index 0f5795aa03c5..fe049c0acfd0 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestBloomFilterStore.java @@ -308,10 +308,10 @@ private static BloomFilterStore generateBloomFilterStore(ParquetTester.TempFile TrinoInputFile inputFile = new LocalInputFile(tempFile.getFile()); TrinoParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, ParquetReaderOptions.defaultOptions(), new FileFormatDataSourceStats()); - ParquetMetadata parquetMetadata = 
MetadataReader.readFooter(dataSource);
+        ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
         ColumnChunkMetadata columnChunkMetaData = getOnlyElement(getOnlyElement(parquetMetadata.getBlocks()).columns());
-        return new BloomFilterStore(dataSource, getOnlyElement(parquetMetadata.getBlocks()), Set.of(columnChunkMetaData.getPath()));
+        return new BloomFilterStore(dataSource, getOnlyElement(parquetMetadata.getBlocks()), Set.of(columnChunkMetaData.getPath()), Optional.empty());
     }
 
     private static class BloomFilterTypeTestCase
diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java
index 532b0568fdd2..8851add79312 100644
--- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java
+++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java
@@ -230,7 +230,7 @@ private static ConnectorPageSource createPageSource(
         try {
             AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext();
             dataSource = createDataSource(inputFile, OptionalLong.of(hudiSplit.fileSize()), options, memoryContext, dataSourceStats);
-            ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize());
+            ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize(), Optional.empty());
             FileMetadata fileMetaData = parquetMetadata.getFileMetaData();
             MessageType fileSchema = fileMetaData.getSchema();
 
@@ -271,7 +271,8 @@ private static ConnectorPageSource createPageSource(
                     options,
                     exception -> handleException(dataSourceId, exception),
                     Optional.of(parquetPredicate),
-                    Optional.empty());
+                    Optional.empty(),
+                    parquetMetadata.getDecryptionContext());
             return createParquetPageSource(columns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider);
         }
         catch (IOException | RuntimeException e) {
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
index aad2add5ec40..c256d6c2c4e8 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
@@ -916,7 +916,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource(
         ParquetDataSource dataSource = null;
         try {
             dataSource = createDataSource(inputFile, OptionalLong.of(fileSize), options, memoryContext, fileFormatDataSourceStats);
-            ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize());
+            ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize(), Optional.empty());
             FileMetadata fileMetaData = parquetMetadata.getFileMetaData();
             MessageType fileSchema = fileMetaData.getSchema();
             if (nameMapping.isPresent() && !ParquetSchemaUtil.hasIds(fileSchema)) {
@@ -1027,6 +1027,7 @@ else if (!parquetIdToFieldName.containsKey(column.getBaseColumn().getId())) {
                     options,
                     exception -> handleException(dataSourceId, exception),
                     Optional.empty(),
+                    Optional.empty(),
                     Optional.empty());
 
             ConnectorPageSource pageSource = new ParquetPageSource(parquetReader);
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java
index 1c7ba6ac43e3..b947b6cb179c 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java
@@ -13,9 +13,11 @@
  */
 package io.trino.plugin.iceberg;
 
+import com.google.common.collect.ImmutableList;
 import io.trino.filesystem.Location;
 import io.trino.filesystem.TrinoOutputFile;
 import io.trino.parquet.ParquetDataSourceId;
+import io.trino.parquet.metadata.BlockMetadata;
 import io.trino.parquet.metadata.ParquetMetadata;
 import io.trino.parquet.writer.ParquetWriterOptions;
 import io.trino.plugin.hive.parquet.ParquetFileWriter;
@@ -29,6 +31,8 @@
 import java.io.Closeable;
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -36,7 +40,6 @@
 
 import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_FILESYSTEM_ERROR;
 import static io.trino.plugin.iceberg.util.ParquetUtil.footerMetrics;
-import static io.trino.plugin.iceberg.util.ParquetUtil.getSplitOffsets;
 import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
 
@@ -83,7 +86,7 @@ public FileMetrics getFileMetrics()
     {
         ParquetMetadata parquetMetadata;
         try {
-            parquetMetadata = new ParquetMetadata(parquetFileWriter.getFileMetadata(), new ParquetDataSourceId(location.toString()));
+            parquetMetadata = new ParquetMetadata(parquetFileWriter.getFileMetadata(), new ParquetDataSourceId(location.toString()), Optional.empty());
             return new FileMetrics(footerMetrics(parquetMetadata, Stream.empty(), metricsConfig), Optional.of(getSplitOffsets(parquetMetadata)));
         }
         catch (IOException | UncheckedIOException e) {
@@ -126,4 +129,16 @@ public long getValidationCpuNanos()
     {
         return parquetFileWriter.getValidationCpuNanos();
     }
+
+    private static List<Long> getSplitOffsets(ParquetMetadata metadata)
+            throws IOException
+    {
+        List<BlockMetadata> blocks = metadata.getBlocks();
+        List<Long> splitOffsets = new ArrayList<>(blocks.size());
+        for (BlockMetadata blockMetaData : blocks) {
+            splitOffsets.add(blockMetaData.columns().getFirst().getStartingPos());
+        }
+        Collections.sort(splitOffsets);
+        return ImmutableList.copyOf(splitOffsets);
+    }
 }
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrationUtils.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrationUtils.java
index 630a5ca4f867..06e9d93b4bba 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrationUtils.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/MigrationUtils.java
@@ -154,7 +154,7 @@ private static Metrics parquetMetrics(TrinoInputFile file, MetricsConfig metrics
     {
         ParquetReaderOptions options = ParquetReaderOptions.defaultOptions();
         try (ParquetDataSource dataSource = new TrinoParquetDataSource(file, ParquetReaderOptions.defaultOptions(), new FileFormatDataSourceStats())) {
-            ParquetMetadata metadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize());
+            ParquetMetadata metadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize(), Optional.empty());
             return ParquetUtil.footerMetrics(metadata, Stream.empty(), metricsConfig, nameMapping);
         }
         catch (IOException e) {
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/ParquetUtil.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/ParquetUtil.java
index 98f50940b419..d25f3aaf6365 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/ParquetUtil.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/ParquetUtil.java
@@ -14,8 +14,6 @@
 package io.trino.plugin.iceberg.util;
 
-import com.google.common.collect.ImmutableList;
-import io.trino.parquet.ParquetCorruptionException;
 import io.trino.parquet.metadata.BlockMetadata;
 import io.trino.parquet.metadata.ColumnChunkMetadata;
 import io.trino.parquet.metadata.ParquetMetadata;
@@ -41,13 +39,12 @@
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.PrimitiveType;
 
+import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -70,7 +67,7 @@ public final class ParquetUtil
     private ParquetUtil() {}
 
     public static Metrics footerMetrics(ParquetMetadata metadata, Stream<FieldMetrics<?>> fieldMetrics, MetricsConfig metricsConfig)
-            throws ParquetCorruptionException
+            throws IOException
     {
         return footerMetrics(metadata, fieldMetrics, metricsConfig, null);
     }
@@ -80,7 +77,7 @@ public static Metrics footerMetrics(
             Stream<FieldMetrics<?>> fieldMetrics,
             MetricsConfig metricsConfig,
             NameMapping nameMapping)
-            throws ParquetCorruptionException
+            throws IOException
     {
         requireNonNull(fieldMetrics, "fieldMetrics should not be null");
 
@@ -158,18 +155,6 @@ public static Metrics footerMetrics(
                 toBufferMap(fileSchema, upperBounds));
     }
 
-    public static List<Long> getSplitOffsets(ParquetMetadata metadata)
-            throws ParquetCorruptionException
-    {
-        List<BlockMetadata> blocks = metadata.getBlocks();
-        List<Long> splitOffsets = new ArrayList<>(blocks.size());
-        for (BlockMetadata blockMetaData : blocks) {
-            splitOffsets.add(blockMetaData.getStartingPos());
-        }
-        Collections.sort(splitOffsets);
-        return ImmutableList.copyOf(splitOffsets);
-    }
-
     private static void updateFromFieldMetrics(
             Map<Integer, FieldMetrics<?>> idToFieldMetricsMap,
             MetricsConfig metricsConfig,
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java
index 826a3d01eeb0..63a0f3205633 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java
@@ -267,7 +267,7 @@ public static Map getMetadataFileAndUpdatedMillis(TrinoFileSystem
     public static ParquetMetadata getParquetFileMetadata(TrinoInputFile inputFile)
     {
         try (TrinoParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, ParquetReaderOptions.defaultOptions(), new FileFormatDataSourceStats())) {
-            return MetadataReader.readFooter(dataSource);
+            return MetadataReader.readFooter(dataSource, Optional.empty());
         }
         catch (IOException e) {
             throw new UncheckedIOException(e);
diff --git a/plugin/trino-lakehouse/src/main/java/io/trino/plugin/lakehouse/LakehouseHiveModule.java b/plugin/trino-lakehouse/src/main/java/io/trino/plugin/lakehouse/LakehouseHiveModule.java
index 2892a81d0116..13e97b466d58 100644
--- a/plugin/trino-lakehouse/src/main/java/io/trino/plugin/lakehouse/LakehouseHiveModule.java
+++ b/plugin/trino-lakehouse/src/main/java/io/trino/plugin/lakehouse/LakehouseHiveModule.java
@@ -40,6 +40,7 @@
 import io.trino.plugin.hive.TransactionalMetadataFactory;
 import io.trino.plugin.hive.avro.AvroFileWriterFactory;
 import io.trino.plugin.hive.avro.AvroPageSourceFactory;
+import io.trino.plugin.hive.crypto.ParquetEncryptionModule;
 import io.trino.plugin.hive.fs.CachingDirectoryLister;
 import io.trino.plugin.hive.fs.DirectoryLister;
 import io.trino.plugin.hive.fs.TransactionScopeCachingDirectoryListerFactory;
@@ -136,5 +137,6 @@ protected void setup(Binder binder)
         fileWriterFactoryBinder.addBinding().to(ParquetFileWriterFactory.class).in(Scopes.SINGLETON);
 
         binder.install(new HiveExecutorModule());
+        install(new ParquetEncryptionModule());
     }
 }
diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java
index bab0f514bb4f..510777c3dbda 100644
--- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java
+++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java
@@ -104,7 +104,7 @@ private ParquetReader parquetReader(TrinoInputFile inputFile, List
     {
         ParquetReaderOptions options = ParquetReaderOptions.defaultOptions();
         TrinoParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats);
-        ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize());
+        ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize(), Optional.empty());
        MessageType fileSchema = parquetMetadata.getFileMetaData().getSchema();
        MessageColumnIO messageColumn = getColumnIO(fileSchema, fileSchema);
        Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
@@ -127,6 +127,7 @@ private ParquetReader parquetReader(TrinoInputFile inputFile, List
                 options,
                 RedshiftParquetPageSource::handleException,
                 Optional.empty(),
+                Optional.empty(),
                 Optional.empty());
     }
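Editorial note, not part of the patch: the recurring change across the call sites above is a new trailing Optional argument on footer reading and reader construction, with Optional.empty() meaning "no decryption input". The sketch below illustrates that pattern under stated assumptions: the class name ReadFooterExample is hypothetical, and the import paths are assumed to match the existing Trino modules; MetadataReader.readFooter, TrinoParquetDataSource, ParquetReaderOptions.defaultOptions(), and getMaxFooterReadSize() are the calls used by the updated code, while a PME-aware caller (such as the Hive connector wired up via ParquetEncryptionModule) would supply a populated value instead of Optional.empty().

import io.trino.filesystem.TrinoInputFile;
import io.trino.parquet.ParquetDataSource;
import io.trino.parquet.ParquetReaderOptions;
import io.trino.parquet.metadata.ParquetMetadata;
import io.trino.parquet.reader.MetadataReader;
import io.trino.plugin.hive.FileFormatDataSourceStats;
import io.trino.plugin.hive.parquet.TrinoParquetDataSource;

import java.io.IOException;
import java.util.Optional;

// Hypothetical helper for illustration only; mirrors the updated call sites such as
// MigrationUtils.parquetMetrics and IcebergTestUtils.getParquetFileMetadata.
final class ReadFooterExample
{
    private ReadFooterExample() {}

    static ParquetMetadata readPlaintextFooter(TrinoInputFile inputFile)
            throws IOException
    {
        ParquetReaderOptions options = ParquetReaderOptions.defaultOptions();
        try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, new FileFormatDataSourceStats())) {
            // The trailing Optional is the per-file decryption input introduced by this
            // patch; Optional.empty() keeps the pre-existing plaintext-only behavior,
            // so footers of encrypted files cannot be decoded through this path.
            return MetadataReader.readFooter(dataSource, options.getMaxFooterReadSize(), Optional.empty());
        }
    }
}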