diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java
index 92b464e97a59..6cb77576a8f3 100644
--- a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java
+++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java
@@ -56,9 +56,11 @@
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
+import java.util.function.Supplier;
 
 import static com.google.common.base.MoreObjects.toStringHelper;
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Suppliers.memoize;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters;
 import static java.util.Objects.requireNonNull;
@@ -77,7 +79,7 @@ public final class SqlStage
 {
     private final Session session;
     private final StageStateMachine stateMachine;
-    private final Map<PlanNodeId, ConnectorTableCredentials> tableCredentials;
+    private final Supplier<Map<PlanNodeId, ConnectorTableCredentials>> tableCredentialsProvider;
     private final RemoteTaskFactory remoteTaskFactory;
     private final NodeTaskMap nodeTaskMap;
     private final boolean summarizeTaskInfo;
@@ -135,7 +137,7 @@ public static SqlStage createSqlStage(
                 nodeTaskMap,
                 summarizeTaskInfo,
                 bucketCountProvider,
-                extractTableCredentials(session, metadata, fragment));
+                memoize(() -> extractTableCredentials(session, metadata, fragment)));
         sqlStage.initialize();
         return sqlStage;
     }
@@ -147,7 +149,7 @@ private SqlStage(
             NodeTaskMap nodeTaskMap,
             boolean summarizeTaskInfo,
             LocalExchangeBucketCountProvider bucketCountProvider,
-            Map<PlanNodeId, ConnectorTableCredentials> tableCredentials)
+            Supplier<Map<PlanNodeId, ConnectorTableCredentials>> tableCredentialsProvider)
     {
         this.session = requireNonNull(session, "session is null");
         this.stateMachine = stateMachine;
@@ -157,7 +159,7 @@ private SqlStage(
         this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");
         this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment());
-        this.tableCredentials = ImmutableMap.copyOf(tableCredentials);
+        this.tableCredentialsProvider = requireNonNull(tableCredentialsProvider, "tableCredentialsProvider is null");
     }
 
     private static Map<PlanNodeId, ConnectorTableCredentials> extractTableCredentials(Session session, Metadata metadata, PlanFragment fragment)
@@ -308,7 +310,7 @@ public synchronized Optional<RemoteTask> createTask(
                 node,
                 speculative,
                 fragment,
-                tableCredentials,
+                tableCredentialsProvider.get(),
                 splits,
                 outputBuffers,
                 nodeTaskMap.createPartitionedSplitCountTracker(node, taskId),
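Note: the SqlStage change above leans entirely on Guava's Suppliers.memoize, which runs the delegate once on the first get() and returns the cached result on every later call. A minimal, self-contained sketch of that behavior (the class name and printed strings are illustrative, not part of the patch):

    import java.util.function.Supplier;

    import static com.google.common.base.Suppliers.memoize;

    public class MemoizeSketch
    {
        public static void main(String[] args)
        {
            // Nothing runs at construction time, mirroring SqlStage.createSqlStage
            Supplier<String> credentials = memoize(() -> {
                System.out.println("extracting credentials"); // printed exactly once
                return "token";
            });

            credentials.get(); // first call: runs the lambda, caches "token"
            credentials.get(); // second call: returns the cached value silently
        }
    }

Suppliers.memoize also serializes the first computation across threads, which is why the memoized supplier is safe to dereference from createTask.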
diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
index 737de712baca..f375c0bd8cf2 100644
--- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
+++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
@@ -17,20 +17,31 @@
 import com.google.common.collect.ImmutableListMultimap;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Multimap;
 import com.google.common.util.concurrent.SettableFuture;
+import io.airlift.units.DataSize;
 import io.opentelemetry.api.trace.Span;
+import io.trino.Session;
+import io.trino.connector.CatalogHandle;
 import io.trino.cost.StatsAndCosts;
+import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker;
+import io.trino.execution.buffer.OutputBuffers;
 import io.trino.execution.buffer.PipelinedOutputBuffers;
 import io.trino.execution.scheduler.SplitSchedulerStats;
+import io.trino.metadata.AbstractMockMetadata;
+import io.trino.metadata.Metadata;
 import io.trino.metadata.Split;
 import io.trino.node.InternalNode;
 import io.trino.operator.RetryPolicy;
 import io.trino.spi.NodeVersion;
 import io.trino.spi.QueryId;
+import io.trino.spi.connector.ConnectorTableCredentials;
+import io.trino.spi.connector.ConnectorTableHandle;
 import io.trino.sql.planner.Partitioning;
 import io.trino.sql.planner.PartitioningScheme;
 import io.trino.sql.planner.PlanFragment;
 import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.plan.DynamicFilterId;
 import io.trino.sql.planner.plan.PlanFragmentId;
 import io.trino.sql.planner.plan.PlanNode;
 import io.trino.sql.planner.plan.PlanNodeId;
@@ -48,18 +59,25 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.OptionalInt;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
 import static io.airlift.tracing.Tracing.noopTracer;
 import static io.trino.SessionTestUtils.TEST_SESSION;
 import static io.trino.execution.SqlStage.createSqlStage;
+import static io.trino.execution.TaskTestUtils.PLAN_FRAGMENT;
+import static io.trino.execution.TaskTestUtils.SPLIT;
+import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID;
 import static io.trino.execution.buffer.PipelinedOutputBuffers.BufferType.ARBITRARY;
 import static io.trino.metadata.AbstractMockMetadata.dummyMetadata;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
@@ -110,6 +128,132 @@ public void testFinalStageInfo()
         }
     }
 
+    @Test
+    public void testTableCredentialsAreLazilyResolved()
+    {
+        AtomicInteger callCount = new AtomicInteger();
+        Metadata metadata = new AbstractMockMetadata()
+        {
+            @Override
+            public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle)
+            {
+                callCount.incrementAndGet();
+                return Optional.empty();
+            }
+        };
+
+        SqlStage stage = createSqlStage(
+                metadata,
+                new StageId(new QueryId("query"), 0),
+                createScanPlanFragment(),
+                ImmutableMap.of(),
+                new MockRemoteTaskFactory(executor, scheduledExecutor),
+                TEST_SESSION,
+                true,
+                new NodeTaskMap(new FinalizerService()),
+                executor,
+                noopTracer(),
+                Span.getInvalid(),
+                new SplitSchedulerStats(),
+                (_, _) -> OptionalInt.empty());
+
+        // Extraction does not run at stage construction
+        assertThat(callCount).hasValue(0);
+
+        // First createTask triggers extraction
+        createTestTask(stage, 0);
+        assertThat(callCount).hasValue(1);
+
+        // Subsequent createTasks reuse the memoized result
+        createTestTask(stage, 1);
+        createTestTask(stage, 2);
+        assertThat(callCount).hasValue(1);
+
+        stage.finish();
+    }
+
+    @Test
+    public void testTableCredentialsSeeLateUpdate()
+    {
+        AtomicReference<Optional<ConnectorTableCredentials>> response = new AtomicReference<>(Optional.empty());
+        Metadata metadata = new AbstractMockMetadata()
+        {
+            @Override
+            public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle)
+            {
+                return response.get();
+            }
+        };
+
+        AtomicReference<Map<PlanNodeId, ConnectorTableCredentials>> capturedCredentials = new AtomicReference<>();
+        RemoteTaskFactory capturingFactory = new RemoteTaskFactory()
+        {
+            private final RemoteTaskFactory delegate = new MockRemoteTaskFactory(executor, scheduledExecutor);
+
+            @Override
+            public RemoteTask createRemoteTask(
+                    Session session,
+                    Span stageSpan,
+                    TaskId taskId,
+                    InternalNode node,
+                    boolean speculative,
+                    PlanFragment fragment,
+                    Map<PlanNodeId, ConnectorTableCredentials> tableCredentials,
+                    Multimap<PlanNodeId, Split> initialSplits,
+                    OutputBuffers outputBuffers,
+                    PartitionedSplitCountTracker partitionedSplitCountTracker,
+                    Set<DynamicFilterId> outboundDynamicFilterIds,
+                    Optional<DataSize> estimatedMemory,
+                    boolean summarizeTaskInfo)
+            {
+                capturedCredentials.set(tableCredentials);
+                return delegate.createRemoteTask(session, stageSpan, taskId, node, speculative, fragment, tableCredentials, initialSplits, outputBuffers, partitionedSplitCountTracker, outboundDynamicFilterIds, estimatedMemory, summarizeTaskInfo);
+            }
+        };
+
+        SqlStage stage = createSqlStage(
+                metadata,
+                new StageId(new QueryId("query"), 0),
+                createScanPlanFragment(),
+                ImmutableMap.of(),
+                capturingFactory,
+                TEST_SESSION,
+                true,
+                new NodeTaskMap(new FinalizerService()),
+                executor,
+                noopTracer(),
+                Span.getInvalid(),
+                new SplitSchedulerStats(),
+                (_, _) -> OptionalInt.empty());
+
+        // Simulate a split source populating credentials after stage construction but before first task creation.
+        ConnectorTableCredentials credentials = new ConnectorTableCredentials() {};
+        response.set(Optional.of(credentials));
+
+        createTestTask(stage, 0);
+
+        assertThat(capturedCredentials.get())
+                .containsExactly(Map.entry(TABLE_SCAN_NODE_ID, credentials));
+
+        stage.finish();
+    }
+
+    private void createTestTask(SqlStage stage, int partition)
+    {
+        InternalNode node = new InternalNode("node" + partition, URI.create("http://node/" + partition), NodeVersion.UNKNOWN, false);
+        stage.createTask(
+                node,
+                partition,
+                0,
+                Optional.empty(),
+                OptionalInt.empty(),
+                PipelinedOutputBuffers.createInitial(ARBITRARY),
+                ImmutableListMultimap.of(TABLE_SCAN_NODE_ID, SPLIT.split()),
+                ImmutableSet.of(),
+                Optional.empty(),
+                false);
+    }
+
     private void testFinalStageInfoInternal()
             throws Exception
     {
@@ -255,4 +399,9 @@ private static PlanFragment createExchangePlanFragment()
                 ImmutableMap.of(),
                 Optional.empty());
     }
+
+    private static PlanFragment createScanPlanFragment()
+    {
+        return PLAN_FRAGMENT.withOutputPartitioning(Optional.empty(), OptionalInt.empty());
+    }
 }
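Note: testTableCredentialsSeeLateUpdate records an argument by wrapping the real factory in a delegating one. A generic sketch of that capture-by-delegation test pattern, with hypothetical names (CapturingFunction is not part of the patch):

    import java.util.concurrent.atomic.AtomicReference;
    import java.util.function.Function;

    final class CapturingFunction<T, R> implements Function<T, R>
    {
        private final Function<T, R> delegate;
        private final AtomicReference<T> lastInput = new AtomicReference<>();

        CapturingFunction(Function<T, R> delegate)
        {
            this.delegate = delegate;
        }

        @Override
        public R apply(T input)
        {
            lastInput.set(input);         // record the argument for later assertions
            return delegate.apply(input); // then forward the call unchanged
        }

        T lastInput()
        {
            return lastInput.get();
        }
    }

The anonymous RemoteTaskFactory above is the same idea specialized to createRemoteTask, capturing only the tableCredentials argument.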
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
index c947ce3712cc..25d2b8d12a80 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
@@ -18,7 +18,6 @@
 import com.google.common.base.Splitter.MapSplitter;
 import com.google.common.base.Suppliers;
 import com.google.common.base.VerifyException;
-import com.google.common.cache.CacheBuilder;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -26,14 +25,12 @@
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import com.google.common.collect.Streams;
-import com.google.common.util.concurrent.UncheckedExecutionException;
 import io.airlift.json.JsonCodec;
 import io.airlift.log.Logger;
 import io.airlift.slice.Slice;
 import io.airlift.slice.Slices;
 import io.airlift.units.DataSize;
 import io.airlift.units.Duration;
-import io.trino.cache.NonEvictableCache;
 import io.trino.filesystem.Location;
 import io.trino.filesystem.TrinoFileSystem;
 import io.trino.metastore.Column;
@@ -247,7 +244,6 @@
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
-import static com.google.common.base.Throwables.throwIfUnchecked;
 import static com.google.common.base.Verify.verify;
 import static com.google.common.base.Verify.verifyNotNull;
 import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -258,8 +254,6 @@
 import static com.google.common.collect.Maps.transformValues;
 import static com.google.common.collect.Sets.difference;
 import static io.airlift.units.Duration.ZERO;
-import static io.trino.cache.CacheUtils.uncheckedCacheGet;
-import static io.trino.cache.SafeCaches.buildNonEvictableCache;
 import static io.trino.filesystem.Locations.isS3Tables;
 import static io.trino.plugin.base.filter.UtcConstraintExtractor.extractTupleDomain;
 import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
@@ -512,7 +506,7 @@ public class IcebergMetadata
     private final int materializedViewRefreshMaxSnapshotsToExpire;
     private final Duration materializedViewRefreshSnapshotRetentionPeriod;
     private final Map<IcebergTableHandle, AtomicReference<TableStatistics>> tableStatisticsCache = new ConcurrentHashMap<>();
-    private final NonEvictableCache<SchemaTableName, ConnectorTableCredentials> tableCredentialsCache;
+    private final IcebergTableCredentialsProvider tableCredentialsProvider;
     private final DeletionVectorWriter deletionVectorWriter;
 
     private Transaction transaction;
@@ -552,42 +546,25 @@ public IcebergMetadata(
         this.deletionVectorWriter = requireNonNull(deletionVectorWriter, "deletionVectorWriter is null");
         this.materializedViewRefreshMaxSnapshotsToExpire = materializedViewRefreshMaxSnapshotsToExpire;
         this.materializedViewRefreshSnapshotRetentionPeriod = materializedViewRefreshSnapshotRetentionPeriod;
-        this.tableCredentialsCache = buildNonEvictableCache(CacheBuilder.newBuilder());
+        this.tableCredentialsProvider = new IcebergTableCredentialsProvider(catalog);
     }
 
     @Override
     public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, ConnectorTableHandle tableHandle)
     {
-        return getOrLoadTableCredentials(session, getSchemaTableName(tableHandle));
+        return tableCredentialsProvider.getTableCredentials(session, getSchemaTableName(tableHandle));
     }
 
     @Override
     public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, ConnectorWritableTableHandle tableHandle)
     {
-        return getOrLoadTableCredentials(session, getSchemaTableName(tableHandle));
+        return tableCredentialsProvider.getTableCredentials(session, getSchemaTableName(tableHandle));
     }
 
     @Override
     public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, ConnectorTableFunctionHandle tableFunctionHandle)
     {
-        return getOrLoadTableCredentials(session, getSchemaTableName(tableFunctionHandle));
-    }
-
-    private Optional<ConnectorTableCredentials> getOrLoadTableCredentials(ConnectorSession session, SchemaTableName schemaTableName)
-    {
-        try {
-            return Optional.of(uncheckedCacheGet(
-                    tableCredentialsCache,
-                    schemaTableName,
-                    () -> {
-                        BaseTable baseTable = catalog.loadTable(session, schemaTableName);
-                        return new IcebergTableCredentials(baseTable.io().properties());
-                    }));
-        }
-        catch (UncheckedExecutionException e) {
-            throwIfUnchecked(e.getCause());
-            throw e;
-        }
+        return tableCredentialsProvider.getTableCredentials(session, getSchemaTableName(tableFunctionHandle));
     }
 
     private static SchemaTableName getSchemaTableName(ConnectorTableHandle tableHandle)
@@ -1599,7 +1576,7 @@ private List<String> getChildNamespaces(ConnectorSession session, String parentN
     private IcebergWritableTableHandle newWritableTableHandle(SchemaTableName name, Table table)
     {
-        tableCredentialsCache.put(name, IcebergTableCredentials.forFileIO(table.io()));
+        tableCredentialsProvider.putTableCredentials(name, IcebergTableCredentials.forFileIO(table.io()));
         SortFieldInfo sortInfo = getSupportedSortFields(table.schema(), table.sortOrder());
         return new IcebergWritableTableHandle(
                 name,
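Note: the IcebergMetadata change trades a Guava NonEvictableCache, together with the uncheckedCacheGet/UncheckedExecutionException unwrapping it required, for ConcurrentHashMap.computeIfAbsent, which gives the same load-once-per-key guarantee while letting unchecked exceptions propagate directly. A minimal sketch of those semantics, with illustrative names (LoadOncePerKey and expensiveLoad are not part of the patch):

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;

    final class LoadOncePerKey
    {
        private final Map<String, String> cache = new ConcurrentHashMap<>();

        String get(String key)
        {
            // The mapping function runs at most once per key; concurrent callers
            // for the same key block until the first load completes. If the loader
            // throws, no mapping is stored and the exception reaches the caller as-is.
            return cache.computeIfAbsent(key, this::expensiveLoad);
        }

        private String expensiveLoad(String key)
        {
            return "credentials-for-" + key; // stand-in for catalog.loadTable(...)
        }
    }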
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableCredentialsProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableCredentialsProvider.java
new file mode 100644
index 000000000000..54856ad86929
--- /dev/null
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableCredentialsProvider.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.iceberg;
+
+import io.trino.plugin.iceberg.catalog.TrinoCatalog;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.ConnectorTableCredentials;
+import io.trino.spi.connector.SchemaTableName;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static java.util.Objects.requireNonNull;
+
+public class IcebergTableCredentialsProvider
+{
+    private final TrinoCatalog catalog;
+    private final Map<SchemaTableName, ConnectorTableCredentials> tableCredentials = new ConcurrentHashMap<>();
+
+    public IcebergTableCredentialsProvider(TrinoCatalog catalog)
+    {
+        this.catalog = requireNonNull(catalog, "catalog is null");
+    }
+
+    public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, SchemaTableName schemaTableName)
+    {
+        return Optional.of(tableCredentials.computeIfAbsent(schemaTableName, key ->
+                new IcebergTableCredentials(catalog.loadTable(session, key).io().properties())));
+    }
+
+    public void putTableCredentials(SchemaTableName schemaTableName, IcebergTableCredentials credentials)
+    {
+        tableCredentials.put(schemaTableName, credentials);
+    }
+}
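Note: putting the two halves together, a self-contained sketch of how the memoized stage-side supplier interacts with the connector-side provider. All names here (Credentials, Provider, CredentialsFlowSketch) are illustrative stand-ins for the Trino classes, not part of the patch; only Suppliers.memoize is the real Guava API:

    import java.util.Map;
    import java.util.Optional;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.function.Supplier;

    import static com.google.common.base.Suppliers.memoize;

    public class CredentialsFlowSketch
    {
        record Credentials(String token) {}

        static class Provider
        {
            private final Map<String, Credentials> cache = new ConcurrentHashMap<>();

            Optional<Credentials> get(String table)
            {
                // lazy load-once, like IcebergTableCredentialsProvider.getTableCredentials
                return Optional.of(cache.computeIfAbsent(table, t -> new Credentials("loaded-" + t)));
            }

            void put(String table, Credentials credentials)
            {
                // pre-seeding from the write path, like newWritableTableHandle
                cache.put(table, credentials);
            }
        }

        public static void main(String[] args)
        {
            Provider provider = new Provider();
            // SqlStage defers extraction until the first createTask via memoize
            Supplier<Optional<Credentials>> stageView = memoize(() -> provider.get("sales.orders"));

            // An update arriving after stage construction but before the first task
            // is still observed, which is what testTableCredentialsSeeLateUpdate checks.
            provider.put("sales.orders", new Credentials("seeded"));
            System.out.println(stageView.get()); // Optional[Credentials[token=seeded]]
        }
    }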