diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputFile.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputFile.java index df12d5ebb9ea..daa071b938d2 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputFile.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputFile.java @@ -16,6 +16,7 @@ import io.trino.filesystem.Location; import io.trino.filesystem.TrinoOutputFile; import io.trino.filesystem.encryption.EncryptionKey; +import io.trino.filesystem.s3.S3OutputStream.ByteArrayStreamProvider; import io.trino.memory.context.AggregatedMemoryContext; import software.amazon.awssdk.services.s3.S3Client; @@ -56,9 +57,7 @@ public void createOrOverwrite(byte[] data) location, key, false, - data, - 0, - data.length); + new ByteArrayStreamProvider(data)); } @Override @@ -71,9 +70,7 @@ public void createExclusive(byte[] data) location, key, true, - data, - 0, - data.length); + new ByteArrayStreamProvider(data)); } @Override diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java index 70c58955f91f..7942a8cba979 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java @@ -14,10 +14,13 @@ package io.trino.filesystem.s3; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; import io.trino.filesystem.FileMayHaveAlreadyExistedException; import io.trino.filesystem.encryption.EncryptionKey; import io.trino.memory.context.AggregatedMemoryContext; import io.trino.memory.context.LocalMemoryContext; +import jakarta.annotation.Nullable; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.s3.S3Client; @@ -33,39 +36,46 @@ import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.InterruptedIOException; import java.io.OutputStream; -import java.nio.ByteBuffer; +import java.io.SequenceInputStream; import java.nio.file.FileAlreadyExistsException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Future; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; +import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.filesystem.s3.S3FileSystemConfig.ObjectCannedAcl.getCannedAcl; import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE; import static io.trino.filesystem.s3.S3FileSystemConfig.StorageClassType.toStorageClass; import static io.trino.filesystem.s3.S3SseCUtils.encoded; import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum; import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings; -import static java.lang.Math.clamp; -import static java.lang.Math.max; import static java.lang.Math.min; +import static java.lang.Math.toIntExact; import static java.lang.System.arraycopy; import static java.net.HttpURLConnection.HTTP_PRECON_FAILED; +import static java.util.Collections.enumeration; import static java.util.Objects.checkFromIndexSize; import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNullElse; import static java.util.concurrent.CompletableFuture.supplyAsync; +import static software.amazon.awssdk.core.internal.util.Mimetype.MIMETYPE_OCTET_STREAM; final class S3OutputStream extends OutputStream { + private static final int INITIAL_BUFFER_SIZE = toIntExact(DataSize.of(1, MEGABYTE).toBytes()); + private static final int MAXIMUM_BUFFER_SIZE = toIntExact(DataSize.of(8, MEGABYTE).toBytes()); + private final List parts = new ArrayList<>(); private final LocalMemoryContext memoryContext; private final Executor uploadExecutor; @@ -77,11 +87,11 @@ final class S3OutputStream private final StorageClass storageClass; private final ObjectCannedACL cannedAcl; private final Optional key; + @Nullable // is null on last upload request + private LinkedBuffer buffer; + private int inFlightBytes; private int currentPartNumber; - private byte[] buffer = new byte[0]; - private int bufferSize; - private int initialBufferSize = 64; private boolean closed; private boolean failed; @@ -105,19 +115,18 @@ public S3OutputStream(AggregatedMemoryContext memoryContext, Executor uploadExec this.storageClass = toStorageClass(context.storageClass()); this.cannedAcl = getCannedAcl(context.cannedAcl()); this.key = requireNonNull(key, "key is null"); - + this.buffer = new LinkedBuffer(INITIAL_BUFFER_SIZE, MAXIMUM_BUFFER_SIZE); + updateMemory(); verify(key.isEmpty() || context.s3SseContext().sseType() == NONE, "Encryption key cannot be used with SSE configuration"); } - @SuppressWarnings("NumericCastThatLosesPrecision") @Override public void write(int b) throws IOException { ensureOpen(); - ensureCapacity(1); - buffer[bufferSize] = (byte) b; - bufferSize++; + buffer.write(b); + updateMemory(); flushBuffer(false); } @@ -125,19 +134,23 @@ public void write(int b) public void write(byte[] bytes, int offset, int length) throws IOException { + requireNonNull(bytes); + checkFromIndexSize(offset, length, bytes.length); ensureOpen(); + // make sure we don't exceed the part size while (length > 0) { - ensureCapacity(length); - - int copied = min(buffer.length - bufferSize, length); - arraycopy(bytes, offset, buffer, bufferSize, copied); - bufferSize += copied; - + int capacity = partSize - buffer.size(); + if (capacity >= length) { + buffer.write(bytes, offset, length); + updateMemory(); + break; + } + buffer.write(bytes, offset, capacity); + updateMemory(); flushBuffer(false); - - offset += copied; - length -= copied; + offset += capacity; + length -= capacity; } } @@ -170,13 +183,15 @@ public void close() try { flushBuffer(true); - memoryContext.close(); waitForPreviousUploadFinish(); } catch (IOException | RuntimeException e) { abortUploadSuppressed(e); throw e; } + finally { + memoryContext.close(); + } try { uploadId.ifPresent(this::finishUpload); @@ -189,6 +204,10 @@ public void close() } throw new IOException(e); } + finally { + buffer = null; + memoryContext.close(); + } } private void ensureOpen() @@ -199,23 +218,10 @@ private void ensureOpen() } } - private void ensureCapacity(int extra) - { - int capacity = min(partSize, bufferSize + extra); - if (buffer.length < capacity) { - int target = max(buffer.length, initialBufferSize); - if (target < capacity) { - target += target / 2; // increase 50% - target = clamp(target, capacity, partSize); - } - buffer = Arrays.copyOf(buffer, target); - memoryContext.setBytes(buffer.length); - } - } - private void flushBuffer(boolean finished) throws IOException { + DataStreamProvider dataStreamProvider = buffer; // skip multipart upload if there would only be one part if (finished && !multipartUploadStarted) { try { @@ -225,9 +231,9 @@ private void flushBuffer(boolean finished) location, key, false, - buffer, - 0, - bufferSize); + dataStreamProvider); + buffer = null; + updateMemory(); return; } catch (Throwable e) { @@ -237,20 +243,7 @@ private void flushBuffer(boolean finished) } // the multipart upload API only allows the last part to be smaller than 5MB - if ((bufferSize == partSize) || (finished && (bufferSize > 0))) { - byte[] data = buffer; - int length = bufferSize; - - if (finished) { - this.buffer = null; - } - else { - this.buffer = new byte[0]; - this.initialBufferSize = partSize; - bufferSize = 0; - } - memoryContext.setBytes(0); - + if ((buffer.size() >= partSize) || (finished && (buffer.size() > 0))) { try { waitForPreviousUploadFinish(); } @@ -259,8 +252,12 @@ private void flushBuffer(boolean finished) abortUploadSuppressed(e); throw e; } + multipartUploadStarted = true; - inProgressUploadFuture = supplyAsync(() -> uploadPage(data, length), uploadExecutor); + inProgressUploadFuture = supplyAsync(() -> uploadPage(dataStreamProvider), uploadExecutor); + inFlightBytes = buffer.allocated(); + buffer = finished ? null : buffer.startNextPart(); + updateMemory(); } } @@ -282,9 +279,12 @@ void waitForPreviousUploadFinish() catch (ExecutionException e) { throw new IOException("Streaming upload failed", e); } + finally { + inFlightBytes = 0; + } } - private CompletedPart uploadPage(byte[] data, int length) + private CompletedPart uploadPage(DataStreamProvider dataStreamProvider) { if (uploadId.isEmpty()) { CreateMultipartUploadRequest request = CreateMultipartUploadRequest.builder() @@ -312,7 +312,7 @@ private CompletedPart uploadPage(byte[] data, int length) .requestPayer(requestPayer) .bucket(location.bucket()) .key(location.key()) - .contentLength((long) length) + .contentLength((long) dataStreamProvider.size()) .uploadId(uploadId.get()) .partNumber(currentPartNumber) .applyMutation(builder -> @@ -324,9 +324,7 @@ private CompletedPart uploadPage(byte[] data, int length) () -> setEncryptionSettings(builder, context.s3SseContext()))) .build(); - ByteBuffer bytes = ByteBuffer.wrap(data, 0, length); - - UploadPartResponse response = client.uploadPart(request, RequestBody.fromByteBuffer(bytes)); + UploadPartResponse response = client.uploadPart(request, RequestBody.fromContentProvider(dataStreamProvider::takeInputStream, dataStreamProvider.size(), MIMETYPE_OCTET_STREAM)); CompletedPart part = CompletedPart.builder() .partNumber(currentPartNumber) @@ -382,19 +380,20 @@ private void abortUploadSuppressed(Throwable throwable) } } + private void updateMemory() + { + memoryContext.setBytes((buffer == null ? 0 : buffer.allocated()) + inFlightBytes); + } + static void putObject( S3Client client, S3Context context, S3Location location, Optional key, boolean exclusiveCreate, - byte[] data, - int dataOffset, - int dataLength) + DataStreamProvider dataStreamProvider) throws IOException { - checkFromIndexSize(dataOffset, dataLength, data.length); - PutObjectRequest request = PutObjectRequest.builder() .overrideConfiguration(context::applyCredentialProviderOverride) .acl(getCannedAcl(context.cannedAcl())) @@ -402,7 +401,7 @@ static void putObject( .bucket(location.bucket()) .key(location.key()) .storageClass(toStorageClass(context.storageClass())) - .contentLength((long) dataLength) + .contentLength((long) dataStreamProvider.size()) .applyMutation(builder -> { if (exclusiveCreate) { builder.ifNoneMatch("*"); @@ -416,10 +415,8 @@ static void putObject( }) .build(); - ByteBuffer bytes = ByteBuffer.wrap(data, dataOffset, dataLength); - try { - client.putObject(request, RequestBody.fromByteBuffer(bytes)); + client.putObject(request, RequestBody.fromContentProvider(dataStreamProvider::takeInputStream, dataStreamProvider.size(), MIMETYPE_OCTET_STREAM)); } catch (SdkException putObjectException) { // When `location` already exists, the operation will fail with `412 Precondition Failed` @@ -439,4 +436,153 @@ static void putObject( throw new IOException("Put failed for bucket [%s] key [%s]: %s".formatted(location.bucket(), location.key(), putObjectException), putObjectException); } } + + interface DataStreamProvider + { + InputStream takeInputStream(); + + int allocated(); + + int size(); + } + + record ByteArrayStreamProvider(byte[] data) + implements DataStreamProvider + { + @Override + public InputStream takeInputStream() + { + return new ByteArrayInputStream(data); + } + + @Override + public int allocated() + { + return data.length; + } + + @Override + public int size() + { + return data.length; + } + } + + @VisibleForTesting + static class LinkedBuffer + implements DataStreamProvider + { + final int initialBufferSize; + final int maxBufferSize; + List parts; + int currentBufferSize; + byte[] currentBuffer; + int currentOffset; + int allocatedSize; + int totalSize; + + public LinkedBuffer(int initialBufferSize, int maxBufferSize) + { + checkArgument(initialBufferSize <= maxBufferSize, "initialBufferSize must be less than or equal to maxBufferSize"); + this.initialBufferSize = initialBufferSize; + this.maxBufferSize = maxBufferSize; + this.parts = new ArrayList<>(); + this.currentBufferSize = initialBufferSize; + this.currentBuffer = new byte[initialBufferSize]; + this.currentOffset = 0; + this.allocatedSize = initialBufferSize; + this.totalSize = 0; + } + + @SuppressWarnings("NumericCastThatLosesPrecision") + public void write(int b) + { + if (remainingCapacity() == 0) { + parts.add(currentBuffer); + resetBuffer(); + } + currentBuffer[currentOffset] = (byte) b; + currentOffset++; + totalSize++; + } + + public void write(byte[] bytes, int offset, int length) + { + writeInternal(bytes, offset, length); + } + + @Override + public int allocated() + { + return allocatedSize; + } + + @Override + public int size() + { + return totalSize; + } + + @Override + public InputStream takeInputStream() + { + ImmutableList.Builder streams = ImmutableList.builderWithExpectedSize(parts.size() + ((currentOffset == 0) ? 0 : 1)); + for (byte[] part : parts) { + streams.add(new ByteArrayInputStream(part)); + } + if (currentOffset > 0) { + streams.add(new ByteArrayInputStream(currentBuffer, 0, currentOffset)); + } + return new SequenceInputStream(enumeration(streams.build())); + } + + public LinkedBuffer startNextPart() + { + return new LinkedBuffer(currentBufferSize, maxBufferSize); + } + + @VisibleForTesting + void reset() + { + allocatedSize = initialBufferSize; + currentBufferSize = initialBufferSize; + currentBuffer = new byte[currentBufferSize]; + currentOffset = 0; + totalSize = 0; + parts = new ArrayList<>(); + } + + private int remainingCapacity() + { + return currentBufferSize - currentOffset; + } + + private void resetBuffer() + { + // scale buffer capacity by 2 or up to the maximum allowed + currentBufferSize = min(currentBufferSize * 2, maxBufferSize); + currentBuffer = new byte[currentBufferSize]; + currentOffset = 0; + allocatedSize += currentBufferSize; + } + + private void writeInternal(byte[] srcBytes, int srcOffset, int srcLength) + { + while (srcLength > 0) { + if (remainingCapacity() == 0) { + parts.add(currentBuffer); + resetBuffer(); + } + + int bytesToWrite = min(srcLength, remainingCapacity()); + + arraycopy(srcBytes, srcOffset, currentBuffer, currentOffset, bytesToWrite); + + currentOffset += bytesToWrite; + totalSize += bytesToWrite; + srcOffset += bytesToWrite; + srcLength -= bytesToWrite; + } + } + } } diff --git a/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestLinkedBuffer.java b/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestLinkedBuffer.java new file mode 100644 index 000000000000..2f49a42d85bc --- /dev/null +++ b/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestLinkedBuffer.java @@ -0,0 +1,226 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.filesystem.s3; + +import io.trino.filesystem.s3.S3OutputStream.LinkedBuffer; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Arrays; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@SuppressWarnings("resource") +public class TestLinkedBuffer +{ + @Test + void testInvalidConfiguration() + { + assertThatThrownBy(() -> new LinkedBuffer(100, 10)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("initialBufferSize must be less than or equal to maxBufferSize"); + } + + @Test + void testEqualConfiguration() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 10); + assertThat(buffer.size()).isZero(); + assertThat(buffer.takeInputStream().readAllBytes()).isEmpty(); + } + + @Test + void testEmptyBuffer() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 100); + + assertThat(buffer.size()).isZero(); + assertThat(buffer.takeInputStream().readAllBytes()).isEmpty(); + } + + @Test + void testSingleWrite() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 100); + + buffer.write(42); + + assertThat(buffer.size()).isEqualTo(1); + assertThat(buffer.takeInputStream().readAllBytes()).containsExactly(42); + } + + @Test + void testMultipleWrite() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 50); + byte[] chunk1 = {1, 2, 3}; + byte[] chunk2 = {4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; + + buffer.write(chunk1, 0, chunk1.length); + buffer.write(chunk2, 0, chunk2.length); + + assertThat(buffer.size()).isEqualTo(14); + + byte[] expected = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; + assertThat(buffer.takeInputStream().readAllBytes()).isEqualTo(expected); + } + + @Test + void testSmallWriteWithOffsetAndLength() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 100); + byte[] source = {0, 0, 1, 2, 3, 0, 0}; + + buffer.write(source, 2, 3); + + assertThat(buffer.size()).isEqualTo(3); + assertThat(buffer.takeInputStream().readAllBytes()).containsExactly(1, 2, 3); + } + + @Test + void testLargeWriteWithOffsetAndLength() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 100); + byte[] source = testBytes(1000); + + buffer.write(source, 300, 700); + + assertThat(buffer.size()).isEqualTo(700); + assertThat(buffer.takeInputStream().readAllBytes()).isEqualTo(Arrays.copyOfRange(source, 300, 1000)); + } + + @Test + void testBufferExpansion() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 100); + byte[] data = testBytes(25); + + buffer.write(data, 0, data.length); + + // writing 25 bytes should trigger expansion + assertThat(buffer.parts.size()).isEqualTo(1); + assertThat(buffer.size()).isEqualTo(25); + assertThat(buffer.takeInputStream().readAllBytes()).isEqualTo(data); + } + + @Test + void testMaxBufferSizeCeiling() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 30); + byte[] data = testBytes(100); + + buffer.write(data, 0, data.length); + + // writing 100 should hit max buffer size + assertThat(buffer.currentBufferSize).isEqualTo(30); + assertThat(buffer.size()).isEqualTo(100); + assertThat(buffer.takeInputStream().readAllBytes()).isEqualTo(data); + } + + @Test + void testReset() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 100); + byte[] data = {1, 2, 3, 4, 5}; + + buffer.write(data, 0, data.length); + assertThat(buffer.size()).isEqualTo(5); + + buffer.reset(); + + assertThat(buffer.size()).isZero(); + assertThat(buffer.takeInputStream().readAllBytes()).isEmpty(); + + buffer.write(data, 0, data.length); + assertThat(buffer.size()).isEqualTo(5); + assertThat(buffer.takeInputStream().readAllBytes()).isEqualTo(data); + } + + @Test + void testTakeInputStream() + throws IOException + { + LinkedBuffer buffer = new LinkedBuffer(10, 30); + byte[] data = testBytes(50); + + buffer.write(data, 0, data.length); + + byte[] firstRead = buffer.takeInputStream().readAllBytes(); + byte[] secondRead = buffer.takeInputStream().readAllBytes(); + + assertThat(firstRead).isEqualTo(data); + assertThat(secondRead).isEqualTo(data); + } + + @Test + void testAllocatedSize() + { + LinkedBuffer buffer = new LinkedBuffer(10, 50); + byte[] data = testBytes(50); + + buffer.write(data, 0, data.length); + + assertThat(buffer.size()).isEqualTo(data.length); + // internal buffer doubles twice + // 10 + 20 + 40 = 70 + assertThat(buffer.allocated()).isEqualTo(70); + } + + @Test + void testLazyBufferAllocation() + { + LinkedBuffer buffer = new LinkedBuffer(10, 50); + byte[] data = testBytes(30); + + buffer.write(data, 0, data.length); + + assertThat(buffer.size()).isEqualTo(data.length); + // internal buffer doubles twice, but should not allocate until the next write + // 10 + 20 = 30 + assertThat(buffer.allocated()).isEqualTo(30); + } + + @Test + void testAllocatedSizeMaxBufferSize() + { + LinkedBuffer buffer = new LinkedBuffer(10, 30); + byte[] data = testBytes(50); + + buffer.write(data, 0, data.length); + + assertThat(buffer.size()).isEqualTo(data.length); + // internal buffer doubles twice, but hits the maximum buffer size + // 10 + 20 + 30 = 60 + assertThat(buffer.allocated()).isEqualTo(60); + } + + private static byte[] testBytes(int length) + { + byte[] result = new byte[length]; + for (int i = 0; i < length; i++) { + result[i] = (byte) (i % 256); + } + return result; + } +}