12 changes: 7 additions & 5 deletions core/trino-main/src/main/java/io/trino/execution/SqlStage.java
@@ -56,9 +56,11 @@
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Suppliers.memoize;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters;
import static java.util.Objects.requireNonNull;
@@ -77,7 +79,7 @@ public final class SqlStage
{
private final Session session;
private final StageStateMachine stateMachine;
private final Map<PlanNodeId, ConnectorTableCredentials> tableCredentials;
private final Supplier<Map<PlanNodeId, ConnectorTableCredentials>> tableCredentialsProvider;
private final RemoteTaskFactory remoteTaskFactory;
private final NodeTaskMap nodeTaskMap;
private final boolean summarizeTaskInfo;
@@ -135,7 +137,7 @@ public static SqlStage createSqlStage(
nodeTaskMap,
summarizeTaskInfo,
bucketCountProvider,
extractTableCredentials(session, metadata, fragment));
memoize(() -> extractTableCredentials(session, metadata, fragment)));
sqlStage.initialize();
return sqlStage;
}
@@ -147,7 +149,7 @@ private SqlStage(
NodeTaskMap nodeTaskMap,
boolean summarizeTaskInfo,
LocalExchangeBucketCountProvider bucketCountProvider,
Map<PlanNodeId, ConnectorTableCredentials> tableCredentials)
Supplier<Map<PlanNodeId, ConnectorTableCredentials>> tableCredentialsProvider)
{
this.session = requireNonNull(session, "session is null");
this.stateMachine = stateMachine;
@@ -157,7 +159,7 @@ private SqlStage(
this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");

this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment());
this.tableCredentials = ImmutableMap.copyOf(tableCredentials);
this.tableCredentialsProvider = requireNonNull(tableCredentialsProvider, "tableCredentialsProvider is null");
}

private static Map<PlanNodeId, ConnectorTableCredentials> extractTableCredentials(Session session, Metadata metadata, PlanFragment fragment)
@@ -308,7 +310,7 @@ public synchronized Optional<RemoteTask> createTask(
node,
speculative,
fragment,
tableCredentials,
tableCredentialsProvider.get(),
splits,
outputBuffers,
nodeTaskMap.createPartitionedSplitCountTracker(node, taskId),
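Note on the mechanics above: Guava's Suppliers.memoize wraps the delegate without invoking it; the delegate runs once, on the first get(), and every later get() returns the cached value (the returned supplier is also safe to share across threads). A minimal, self-contained sketch of that contract — the counter and the string value are illustrative only, not part of this change:

import com.google.common.base.Suppliers;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

public class MemoizeSketch
{
    public static void main(String[] args)
    {
        AtomicInteger calls = new AtomicInteger();

        // memoize only wraps the lambda; nothing is resolved yet
        Supplier<String> credentials = Suppliers.memoize(() -> {
            calls.incrementAndGet();
            return "resolved-credentials";
        });
        System.out.println(calls.get()); // prints 0

        credentials.get(); // first get() invokes the delegate
        credentials.get(); // served from the cached value
        System.out.println(calls.get()); // prints 1
    }
}

This is the property the new test below (testTableCredentialsAreLazilyResolved) asserts against SqlStage directly.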
149 changes: 149 additions & 0 deletions core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
@@ -17,20 +17,31 @@
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.units.DataSize;
import io.opentelemetry.api.trace.Span;
import io.trino.Session;
import io.trino.connector.CatalogHandle;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.metadata.Metadata;
import io.trino.metadata.Split;
import io.trino.node.InternalNode;
import io.trino.operator.RetryPolicy;
import io.trino.spi.NodeVersion;
import io.trino.spi.QueryId;
import io.trino.spi.connector.ConnectorTableCredentials;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
@@ -48,18 +59,25 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.tracing.Tracing.noopTracer;
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.execution.SqlStage.createSqlStage;
import static io.trino.execution.TaskTestUtils.PLAN_FRAGMENT;
import static io.trino.execution.TaskTestUtils.SPLIT;
import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID;
import static io.trino.execution.buffer.PipelinedOutputBuffers.BufferType.ARBITRARY;
import static io.trino.metadata.AbstractMockMetadata.dummyMetadata;
import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
@@ -110,6 +128,132 @@ public void testFinalStageInfo()
}
}

@Test
public void testTableCredentialsAreLazilyResolved()
{
AtomicInteger callCount = new AtomicInteger();
Metadata metadata = new AbstractMockMetadata()
{
@Override
public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle)
{
callCount.incrementAndGet();
return Optional.empty();
}
};

SqlStage stage = createSqlStage(
metadata,
new StageId(new QueryId("query"), 0),
createScanPlanFragment(),
ImmutableMap.of(),
new MockRemoteTaskFactory(executor, scheduledExecutor),
TEST_SESSION,
true,
new NodeTaskMap(new FinalizerService()),
executor,
noopTracer(),
Span.getInvalid(),
new SplitSchedulerStats(),
(_, _) -> OptionalInt.empty());

// Extraction does not run at stage construction
assertThat(callCount).hasValue(0);

// First createTask triggers extraction
createTestTask(stage, 0);
assertThat(callCount).hasValue(1);

// Subsequent createTasks reuse the memoized result
createTestTask(stage, 1);
createTestTask(stage, 2);
assertThat(callCount).hasValue(1);

stage.finish();
}

@Test
public void testTableCredentialsSeeLateUpdate()
{
AtomicReference<Optional<ConnectorTableCredentials>> response = new AtomicReference<>(Optional.empty());
Metadata metadata = new AbstractMockMetadata()
{
@Override
public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle)
{
return response.get();
}
};

AtomicReference<Map<PlanNodeId, ConnectorTableCredentials>> capturedCredentials = new AtomicReference<>();
RemoteTaskFactory capturingFactory = new RemoteTaskFactory()
{
private final RemoteTaskFactory delegate = new MockRemoteTaskFactory(executor, scheduledExecutor);

@Override
public RemoteTask createRemoteTask(
Session session,
Span stageSpan,
TaskId taskId,
InternalNode node,
boolean speculative,
PlanFragment fragment,
Map<PlanNodeId, ConnectorTableCredentials> tableCredentials,
Multimap<PlanNodeId, Split> initialSplits,
OutputBuffers outputBuffers,
PartitionedSplitCountTracker partitionedSplitCountTracker,
Set<DynamicFilterId> outboundDynamicFilterIds,
Optional<DataSize> estimatedMemory,
boolean summarizeTaskInfo)
{
capturedCredentials.set(tableCredentials);
return delegate.createRemoteTask(session, stageSpan, taskId, node, speculative, fragment, tableCredentials, initialSplits, outputBuffers, partitionedSplitCountTracker, outboundDynamicFilterIds, estimatedMemory, summarizeTaskInfo);
}
};

SqlStage stage = createSqlStage(
metadata,
new StageId(new QueryId("query"), 0),
createScanPlanFragment(),
ImmutableMap.of(),
capturingFactory,
TEST_SESSION,
true,
new NodeTaskMap(new FinalizerService()),
executor,
noopTracer(),
Span.getInvalid(),
new SplitSchedulerStats(),
(_, _) -> OptionalInt.empty());

// Simulate a split source populating credentials after stage construction but before first task creation.
ConnectorTableCredentials credentials = new ConnectorTableCredentials() {};
response.set(Optional.of(credentials));

createTestTask(stage, 0);

assertThat(capturedCredentials.get())
.containsExactly(Map.entry(TABLE_SCAN_NODE_ID, credentials));

stage.finish();
}

private void createTestTask(SqlStage stage, int partition)
{
InternalNode node = new InternalNode("node" + partition, URI.create("http://node/" + partition), NodeVersion.UNKNOWN, false);
stage.createTask(
node,
partition,
0,
Optional.empty(),
OptionalInt.empty(),
PipelinedOutputBuffers.createInitial(ARBITRARY),
ImmutableListMultimap.of(TABLE_SCAN_NODE_ID, SPLIT.split()),
ImmutableSet.of(),
Optional.empty(),
false);
}

private void testFinalStageInfoInternal()
throws Exception
{
@@ -255,4 +399,9 @@ private static PlanFragment createExchangePlanFragment()
ImmutableMap.of(),
Optional.empty());
}

private static PlanFragment createScanPlanFragment()
{
return PLAN_FRAGMENT.withOutputPartitioning(Optional.empty(), OptionalInt.empty());
}
}
plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
@@ -18,22 +18,19 @@
import com.google.common.base.Splitter.MapSplitter;
import com.google.common.base.Suppliers;
import com.google.common.base.VerifyException;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.cache.NonEvictableCache;
import io.trino.filesystem.Location;
import io.trino.filesystem.TrinoFileSystem;
import io.trino.metastore.Column;
@@ -247,7 +244,6 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static com.google.common.base.Verify.verifyNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -258,8 +254,6 @@
import static com.google.common.collect.Maps.transformValues;
import static com.google.common.collect.Sets.difference;
import static io.airlift.units.Duration.ZERO;
import static io.trino.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.filesystem.Locations.isS3Tables;
import static io.trino.plugin.base.filter.UtcConstraintExtractor.extractTupleDomain;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
@@ -512,7 +506,7 @@ public class IcebergMetadata
private final int materializedViewRefreshMaxSnapshotsToExpire;
private final Duration materializedViewRefreshSnapshotRetentionPeriod;
private final Map<IcebergTableHandle, AtomicReference<TableStatistics>> tableStatisticsCache = new ConcurrentHashMap<>();
private final NonEvictableCache<SchemaTableName, IcebergTableCredentials> tableCredentialsCache;
private final IcebergTableCredentialsProvider tableCredentialsProvider;
private final DeletionVectorWriter deletionVectorWriter;

private Transaction transaction;
@@ -552,42 +546,25 @@ public IcebergMetadata(
this.deletionVectorWriter = requireNonNull(deletionVectorWriter, "deletionVectorWriter is null");
this.materializedViewRefreshMaxSnapshotsToExpire = materializedViewRefreshMaxSnapshotsToExpire;
this.materializedViewRefreshSnapshotRetentionPeriod = materializedViewRefreshSnapshotRetentionPeriod;
this.tableCredentialsCache = buildNonEvictableCache(CacheBuilder.newBuilder());
this.tableCredentialsProvider = new IcebergTableCredentialsProvider(catalog);
}

@Override
public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, ConnectorTableHandle tableHandle)
{
return getOrLoadTableCredentials(session, getSchemaTableName(tableHandle));
return tableCredentialsProvider.getTableCredentials(session, getSchemaTableName(tableHandle));
}

@Override
public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, ConnectorWritableTableHandle tableHandle)
{
return getOrLoadTableCredentials(session, getSchemaTableName(tableHandle));
return tableCredentialsProvider.getTableCredentials(session, getSchemaTableName(tableHandle));
}

@Override
public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, ConnectorTableFunctionHandle tableFunctionHandle)
{
return getOrLoadTableCredentials(session, getSchemaTableName(tableFunctionHandle));
}

private Optional<ConnectorTableCredentials> getOrLoadTableCredentials(ConnectorSession session, SchemaTableName schemaTableName)
{
try {
return Optional.of(uncheckedCacheGet(
tableCredentialsCache,
schemaTableName,
() -> {
BaseTable baseTable = catalog.loadTable(session, schemaTableName);
return new IcebergTableCredentials(baseTable.io().properties());
}));
}
catch (UncheckedExecutionException e) {
throwIfUnchecked(e.getCause());
throw e;
}
return tableCredentialsProvider.getTableCredentials(session, getSchemaTableName(tableFunctionHandle));
}

private static SchemaTableName getSchemaTableName(ConnectorTableHandle tableHandle)
@@ -1599,7 +1576,7 @@ private List<String> getChildNamespaces(ConnectorSession session, String parentNamespace)

private IcebergWritableTableHandle newWritableTableHandle(SchemaTableName name, Table table)
{
tableCredentialsCache.put(name, IcebergTableCredentials.forFileIO(table.io()));
tableCredentialsProvider.putTableCredentials(name, IcebergTableCredentials.forFileIO(table.io()));
SortFieldInfo sortInfo = getSupportedSortFields(table.schema(), table.sortOrder());
return new IcebergWritableTableHandle(
name,
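The IcebergTableCredentialsProvider class that IcebergMetadata now delegates to is not part of this excerpt. Judging purely from the inline code removed above — a non-evictable cache keyed by SchemaTableName, a load path through catalog.loadTable, and a put path used by newWritableTableHandle — a minimal sketch could look like the following. The constructor parameter type (TrinoCatalog) and everything beyond the three call sites visible in the diff are assumptions:

import com.google.common.cache.CacheBuilder;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.trino.cache.NonEvictableCache;
import io.trino.plugin.iceberg.catalog.TrinoCatalog;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableCredentials;
import io.trino.spi.connector.SchemaTableName;
import org.apache.iceberg.BaseTable;

import java.util.Optional;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.trino.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.cache.SafeCaches.buildNonEvictableCache;
import static java.util.Objects.requireNonNull;

// Hypothetical reconstruction -- the real class is not shown in this diff.
public class IcebergTableCredentialsProvider
{
    private final TrinoCatalog catalog;
    private final NonEvictableCache<SchemaTableName, IcebergTableCredentials> cache =
            buildNonEvictableCache(CacheBuilder.newBuilder());

    public IcebergTableCredentialsProvider(TrinoCatalog catalog)
    {
        this.catalog = requireNonNull(catalog, "catalog is null");
    }

    public Optional<ConnectorTableCredentials> getTableCredentials(ConnectorSession session, SchemaTableName schemaTableName)
    {
        try {
            // Same load-once logic that previously lived in IcebergMetadata.getOrLoadTableCredentials
            return Optional.of(uncheckedCacheGet(cache, schemaTableName, () -> {
                BaseTable baseTable = catalog.loadTable(session, schemaTableName);
                return new IcebergTableCredentials(baseTable.io().properties());
            }));
        }
        catch (UncheckedExecutionException e) {
            throwIfUnchecked(e.getCause());
            throw e;
        }
    }

    // Write path: newWritableTableHandle seeds credentials from an already-loaded table
    public void putTableCredentials(SchemaTableName schemaTableName, IcebergTableCredentials credentials)
    {
        cache.put(schemaTableName, credentials);
    }
}

If the sketch is accurate, the refactor keeps the caching behavior identical and only moves it behind a reusable seam, so the same provider can back all three getTableCredentials overloads.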