Skip to content

Commit 3c08065

Browse files
Resolve credentials lazily in SqlStage
Use a Supplier to resolve credentials lazily in SqlStage. This allows the use of credentials that are created after the stage has been constructed, e.g. per-split vended credentials.
1 parent 676f043 commit 3c08065

2 files changed

Lines changed: 155 additions & 5 deletions

File tree

core/trino-main/src/main/java/io/trino/execution/SqlStage.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@
5656
import java.util.concurrent.Executor;
5757
import java.util.concurrent.TimeUnit;
5858
import java.util.function.Function;
59+
import java.util.function.Supplier;
5960

6061
import static com.google.common.base.MoreObjects.toStringHelper;
6162
import static com.google.common.base.Preconditions.checkArgument;
63+
import static com.google.common.base.Suppliers.memoize;
6264
import static com.google.common.collect.ImmutableList.toImmutableList;
6365
import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters;
6466
import static java.util.Objects.requireNonNull;
@@ -77,7 +79,7 @@ public final class SqlStage
7779
{
7880
private final Session session;
7981
private final StageStateMachine stateMachine;
80-
private final Map<PlanNodeId, ConnectorTableCredentials> tableCredentials;
82+
private final Supplier<Map<PlanNodeId, ConnectorTableCredentials>> tableCredentialsProvider;
8183
private final RemoteTaskFactory remoteTaskFactory;
8284
private final NodeTaskMap nodeTaskMap;
8385
private final boolean summarizeTaskInfo;
@@ -135,7 +137,7 @@ public static SqlStage createSqlStage(
135137
nodeTaskMap,
136138
summarizeTaskInfo,
137139
bucketCountProvider,
138-
extractTableCredentials(session, metadata, fragment));
140+
memoize(() -> extractTableCredentials(session, metadata, fragment)));
139141
sqlStage.initialize();
140142
return sqlStage;
141143
}
@@ -147,7 +149,7 @@ private SqlStage(
147149
NodeTaskMap nodeTaskMap,
148150
boolean summarizeTaskInfo,
149151
LocalExchangeBucketCountProvider bucketCountProvider,
150-
Map<PlanNodeId, ConnectorTableCredentials> tableCredentials)
152+
Supplier<Map<PlanNodeId, ConnectorTableCredentials>> tableCredentialsProvider)
151153
{
152154
this.session = requireNonNull(session, "session is null");
153155
this.stateMachine = stateMachine;
@@ -157,7 +159,7 @@ private SqlStage(
157159
this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");
158160

159161
this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment());
160-
this.tableCredentials = ImmutableMap.copyOf(tableCredentials);
162+
this.tableCredentialsProvider = requireNonNull(tableCredentialsProvider, "tableCredentialsProvider is null");
161163
}
162164

163165
private static Map<PlanNodeId, ConnectorTableCredentials> extractTableCredentials(Session session, Metadata metadata, PlanFragment fragment)
@@ -308,7 +310,7 @@ public synchronized Optional<RemoteTask> createTask(
308310
node,
309311
speculative,
310312
fragment,
311-
tableCredentials,
313+
tableCredentialsProvider.get(),
312314
splits,
313315
outputBuffers,
314316
nodeTaskMap.createPartitionedSplitCountTracker(node, taskId),

core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,31 @@
1717
import com.google.common.collect.ImmutableListMultimap;
1818
import com.google.common.collect.ImmutableMap;
1919
import com.google.common.collect.ImmutableSet;
20+
import com.google.common.collect.Multimap;
2021
import com.google.common.util.concurrent.SettableFuture;
22+
import io.airlift.units.DataSize;
2123
import io.opentelemetry.api.trace.Span;
24+
import io.trino.Session;
25+
import io.trino.connector.CatalogHandle;
2226
import io.trino.cost.StatsAndCosts;
27+
import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker;
28+
import io.trino.execution.buffer.OutputBuffers;
2329
import io.trino.execution.buffer.PipelinedOutputBuffers;
2430
import io.trino.execution.scheduler.SplitSchedulerStats;
31+
import io.trino.metadata.AbstractMockMetadata;
32+
import io.trino.metadata.Metadata;
2533
import io.trino.metadata.Split;
2634
import io.trino.node.InternalNode;
2735
import io.trino.operator.RetryPolicy;
2836
import io.trino.spi.NodeVersion;
2937
import io.trino.spi.QueryId;
38+
import io.trino.spi.connector.ConnectorTableCredentials;
39+
import io.trino.spi.connector.ConnectorTableHandle;
3040
import io.trino.sql.planner.Partitioning;
3141
import io.trino.sql.planner.PartitioningScheme;
3242
import io.trino.sql.planner.PlanFragment;
3343
import io.trino.sql.planner.Symbol;
44+
import io.trino.sql.planner.plan.DynamicFilterId;
3445
import io.trino.sql.planner.plan.PlanFragmentId;
3546
import io.trino.sql.planner.plan.PlanNode;
3647
import io.trino.sql.planner.plan.PlanNodeId;
@@ -48,18 +59,25 @@
4859
import java.util.ArrayList;
4960
import java.util.Collections;
5061
import java.util.List;
62+
import java.util.Map;
5163
import java.util.Optional;
5264
import java.util.OptionalInt;
65+
import java.util.Set;
5366
import java.util.concurrent.CompletableFuture;
5467
import java.util.concurrent.CountDownLatch;
5568
import java.util.concurrent.ExecutorService;
5669
import java.util.concurrent.Future;
5770
import java.util.concurrent.ScheduledExecutorService;
71+
import java.util.concurrent.atomic.AtomicInteger;
72+
import java.util.concurrent.atomic.AtomicReference;
5873

5974
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
6075
import static io.airlift.tracing.Tracing.noopTracer;
6176
import static io.trino.SessionTestUtils.TEST_SESSION;
6277
import static io.trino.execution.SqlStage.createSqlStage;
78+
import static io.trino.execution.TaskTestUtils.PLAN_FRAGMENT;
79+
import static io.trino.execution.TaskTestUtils.SPLIT;
80+
import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID;
6381
import static io.trino.execution.buffer.PipelinedOutputBuffers.BufferType.ARBITRARY;
6482
import static io.trino.metadata.AbstractMockMetadata.dummyMetadata;
6583
import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
@@ -110,6 +128,131 @@ public void testFinalStageInfo()
110128
}
111129
}
112130

131+
@Test
132+
public void testTableCredentialsAreLazilyResolved()
133+
{
134+
AtomicInteger callCount = new AtomicInteger();
135+
Metadata metadata = new AbstractMockMetadata()
136+
{
137+
@Override
138+
public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle)
139+
{
140+
callCount.incrementAndGet();
141+
return Optional.empty();
142+
}
143+
};
144+
145+
SqlStage stage = createSqlStage(
146+
metadata,
147+
new StageId(new QueryId("query"), 0),
148+
createScanPlanFragment(),
149+
ImmutableMap.of(),
150+
new MockRemoteTaskFactory(executor, scheduledExecutor),
151+
TEST_SESSION,
152+
true,
153+
new NodeTaskMap(new FinalizerService()),
154+
executor,
155+
noopTracer(),
156+
Span.getInvalid(),
157+
new SplitSchedulerStats(),
158+
(_, _) -> OptionalInt.empty());
159+
160+
// Extraction does not run at stage construction
161+
assertThat(callCount).hasValue(0);
162+
163+
// First createTask triggers extraction
164+
createTestTask(stage, 0);
165+
assertThat(callCount).hasValue(1);
166+
167+
// Subsequent createTasks reuse the memoized result
168+
createTestTask(stage, 1);
169+
createTestTask(stage, 2);
170+
assertThat(callCount).hasValue(1);
171+
172+
stage.finish();
173+
}
174+
175+
@Test
176+
public void testTableCredentialsSeeLateUpdate()
177+
{
178+
AtomicReference<Optional<ConnectorTableCredentials>> response = new AtomicReference<>(Optional.empty());
179+
Metadata metadata = new AbstractMockMetadata()
180+
{
181+
@Override
182+
public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle)
183+
{
184+
return response.get();
185+
}
186+
};
187+
188+
AtomicReference<Map<PlanNodeId, ConnectorTableCredentials>> capturedCredentials = new AtomicReference<>();
189+
RemoteTaskFactory capturingFactory = new RemoteTaskFactory() {
190+
private final RemoteTaskFactory delegate = new MockRemoteTaskFactory(executor, scheduledExecutor);
191+
192+
@Override
193+
public RemoteTask createRemoteTask(
194+
Session session,
195+
Span stageSpan,
196+
TaskId taskId,
197+
InternalNode node,
198+
boolean speculative,
199+
PlanFragment fragment,
200+
Map<PlanNodeId, ConnectorTableCredentials> tableCredentials,
201+
Multimap<PlanNodeId, Split> initialSplits,
202+
OutputBuffers outputBuffers,
203+
PartitionedSplitCountTracker partitionedSplitCountTracker,
204+
Set<DynamicFilterId> outboundDynamicFilterIds,
205+
Optional<DataSize> estimatedMemory,
206+
boolean summarizeTaskInfo)
207+
{
208+
capturedCredentials.set(tableCredentials);
209+
return delegate.createRemoteTask(session, stageSpan, taskId, node, speculative, fragment, tableCredentials, initialSplits, outputBuffers, partitionedSplitCountTracker, outboundDynamicFilterIds, estimatedMemory, summarizeTaskInfo);
210+
}
211+
};
212+
213+
SqlStage stage = createSqlStage(
214+
metadata,
215+
new StageId(new QueryId("query"), 0),
216+
createScanPlanFragment(),
217+
ImmutableMap.of(),
218+
capturingFactory,
219+
TEST_SESSION,
220+
true,
221+
new NodeTaskMap(new FinalizerService()),
222+
executor,
223+
noopTracer(),
224+
Span.getInvalid(),
225+
new SplitSchedulerStats(),
226+
(_, _) -> OptionalInt.empty());
227+
228+
// Simulate a split source populating credentials after stage construction but before first task creation.
229+
ConnectorTableCredentials credentials = new ConnectorTableCredentials() {};
230+
response.set(Optional.of(credentials));
231+
232+
createTestTask(stage, 0);
233+
234+
assertThat(capturedCredentials.get())
235+
.containsExactly(Map.entry(TABLE_SCAN_NODE_ID, credentials));
236+
237+
stage.finish();
238+
}
239+
240+
private void createTestTask(SqlStage stage, int partition)
241+
{
242+
InternalNode node = new InternalNode("node" + partition, URI.create("http://node/" + partition), NodeVersion.UNKNOWN, false);
243+
stage.createTask(
244+
node,
245+
partition,
246+
0,
247+
Optional.empty(),
248+
OptionalInt.empty(),
249+
PipelinedOutputBuffers.createInitial(ARBITRARY),
250+
ImmutableListMultimap.of(TABLE_SCAN_NODE_ID, SPLIT.split()),
251+
ImmutableSet.of(),
252+
Optional.empty(),
253+
false);
254+
}
255+
113256
private void testFinalStageInfoInternal()
114257
throws Exception
115258
{
@@ -255,4 +398,9 @@ private static PlanFragment createExchangePlanFragment()
255398
ImmutableMap.of(),
256399
Optional.empty());
257400
}
401+
402+
private static PlanFragment createScanPlanFragment()
403+
{
404+
return PLAN_FRAGMENT.withOutputPartitioning(Optional.empty(), OptionalInt.empty());
405+
}
258406
}

0 commit comments

Comments (0)