|
17 | 17 | import com.google.common.collect.ImmutableListMultimap; |
18 | 18 | import com.google.common.collect.ImmutableMap; |
19 | 19 | import com.google.common.collect.ImmutableSet; |
| 20 | +import com.google.common.collect.Multimap; |
20 | 21 | import com.google.common.util.concurrent.SettableFuture; |
| 22 | +import io.airlift.units.DataSize; |
21 | 23 | import io.opentelemetry.api.trace.Span; |
| 24 | +import io.trino.Session; |
| 25 | +import io.trino.connector.CatalogHandle; |
22 | 26 | import io.trino.cost.StatsAndCosts; |
| 27 | +import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; |
| 28 | +import io.trino.execution.buffer.OutputBuffers; |
23 | 29 | import io.trino.execution.buffer.PipelinedOutputBuffers; |
24 | 30 | import io.trino.execution.scheduler.SplitSchedulerStats; |
| 31 | +import io.trino.metadata.AbstractMockMetadata; |
| 32 | +import io.trino.metadata.Metadata; |
25 | 33 | import io.trino.metadata.Split; |
26 | 34 | import io.trino.node.InternalNode; |
27 | 35 | import io.trino.operator.RetryPolicy; |
28 | 36 | import io.trino.spi.NodeVersion; |
29 | 37 | import io.trino.spi.QueryId; |
| 38 | +import io.trino.spi.connector.ConnectorTableCredentials; |
| 39 | +import io.trino.spi.connector.ConnectorTableHandle; |
30 | 40 | import io.trino.sql.planner.Partitioning; |
31 | 41 | import io.trino.sql.planner.PartitioningScheme; |
32 | 42 | import io.trino.sql.planner.PlanFragment; |
33 | 43 | import io.trino.sql.planner.Symbol; |
| 44 | +import io.trino.sql.planner.plan.DynamicFilterId; |
34 | 45 | import io.trino.sql.planner.plan.PlanFragmentId; |
35 | 46 | import io.trino.sql.planner.plan.PlanNode; |
36 | 47 | import io.trino.sql.planner.plan.PlanNodeId; |
|
48 | 59 | import java.util.ArrayList; |
49 | 60 | import java.util.Collections; |
50 | 61 | import java.util.List; |
| 62 | +import java.util.Map; |
51 | 63 | import java.util.Optional; |
52 | 64 | import java.util.OptionalInt; |
| 65 | +import java.util.Set; |
53 | 66 | import java.util.concurrent.CompletableFuture; |
54 | 67 | import java.util.concurrent.CountDownLatch; |
55 | 68 | import java.util.concurrent.ExecutorService; |
56 | 69 | import java.util.concurrent.Future; |
57 | 70 | import java.util.concurrent.ScheduledExecutorService; |
| 71 | +import java.util.concurrent.atomic.AtomicInteger; |
| 72 | +import java.util.concurrent.atomic.AtomicReference; |
58 | 73 |
|
59 | 74 | import static io.airlift.concurrent.Threads.daemonThreadsNamed; |
60 | 75 | import static io.airlift.tracing.Tracing.noopTracer; |
61 | 76 | import static io.trino.SessionTestUtils.TEST_SESSION; |
62 | 77 | import static io.trino.execution.SqlStage.createSqlStage; |
| 78 | +import static io.trino.execution.TaskTestUtils.PLAN_FRAGMENT; |
| 79 | +import static io.trino.execution.TaskTestUtils.SPLIT; |
| 80 | +import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID; |
63 | 81 | import static io.trino.execution.buffer.PipelinedOutputBuffers.BufferType.ARBITRARY; |
64 | 82 | import static io.trino.metadata.AbstractMockMetadata.dummyMetadata; |
65 | 83 | import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; |
@@ -110,6 +128,131 @@ public void testFinalStageInfo() |
110 | 128 | } |
111 | 129 | } |
112 | 130 |
|
| 131 | + @Test |
| 132 | + public void testTableCredentialsAreLazilyResolved() |
| 133 | + { |
| 134 | + AtomicInteger callCount = new AtomicInteger(); |
| 135 | + Metadata metadata = new AbstractMockMetadata() |
| 136 | + { |
| 137 | + @Override |
| 138 | + public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle) |
| 139 | + { |
| 140 | + callCount.incrementAndGet(); |
| 141 | + return Optional.empty(); |
| 142 | + } |
| 143 | + }; |
| 144 | + |
| 145 | + SqlStage stage = createSqlStage( |
| 146 | + metadata, |
| 147 | + new StageId(new QueryId("query"), 0), |
| 148 | + createScanPlanFragment(), |
| 149 | + ImmutableMap.of(), |
| 150 | + new MockRemoteTaskFactory(executor, scheduledExecutor), |
| 151 | + TEST_SESSION, |
| 152 | + true, |
| 153 | + new NodeTaskMap(new FinalizerService()), |
| 154 | + executor, |
| 155 | + noopTracer(), |
| 156 | + Span.getInvalid(), |
| 157 | + new SplitSchedulerStats(), |
| 158 | + (_, _) -> OptionalInt.empty()); |
| 159 | + |
| 160 | + // Extraction does not run at stage construction |
| 161 | + assertThat(callCount).hasValue(0); |
| 162 | + |
| 163 | + // First createTask triggers extraction |
| 164 | + createTestTask(stage, 0); |
| 165 | + assertThat(callCount).hasValue(1); |
| 166 | + |
| 167 | + // Subsequent createTasks reuse the memoized result |
| 168 | + createTestTask(stage, 1); |
| 169 | + createTestTask(stage, 2); |
| 170 | + assertThat(callCount).hasValue(1); |
| 171 | + |
| 172 | + stage.finish(); |
| 173 | + } |
| 174 | + |
| 175 | + @Test |
| 176 | + public void testTableCredentialsSeeLateUpdate() |
| 177 | + { |
| 178 | + AtomicReference<Optional<ConnectorTableCredentials>> response = new AtomicReference<>(Optional.empty()); |
| 179 | + Metadata metadata = new AbstractMockMetadata() |
| 180 | + { |
| 181 | + @Override |
| 182 | + public Optional<ConnectorTableCredentials> getTableCredentials(Session s, CatalogHandle catalogHandle, ConnectorTableHandle tableHandle) |
| 183 | + { |
| 184 | + return response.get(); |
| 185 | + } |
| 186 | + }; |
| 187 | + |
| 188 | + AtomicReference<Map<PlanNodeId, ConnectorTableCredentials>> capturedCredentials = new AtomicReference<>(); |
| 189 | + RemoteTaskFactory capturingFactory = new RemoteTaskFactory() { |
| 190 | + private final RemoteTaskFactory delegate = new MockRemoteTaskFactory(executor, scheduledExecutor); |
| 191 | + |
| 192 | + @Override |
| 193 | + public RemoteTask createRemoteTask( |
| 194 | + Session session, |
| 195 | + Span stageSpan, |
| 196 | + TaskId taskId, |
| 197 | + InternalNode node, |
| 198 | + boolean speculative, |
| 199 | + PlanFragment fragment, |
| 200 | + Map<PlanNodeId, ConnectorTableCredentials> tableCredentials, |
| 201 | + Multimap<PlanNodeId, Split> initialSplits, |
| 202 | + OutputBuffers outputBuffers, |
| 203 | + PartitionedSplitCountTracker partitionedSplitCountTracker, |
| 204 | + Set<DynamicFilterId> outboundDynamicFilterIds, |
| 205 | + Optional<DataSize> estimatedMemory, |
| 206 | + boolean summarizeTaskInfo) |
| 207 | + { |
| 208 | + capturedCredentials.set(tableCredentials); |
| 209 | + return delegate.createRemoteTask(session, stageSpan, taskId, node, speculative, fragment, tableCredentials, initialSplits, outputBuffers, partitionedSplitCountTracker, outboundDynamicFilterIds, estimatedMemory, summarizeTaskInfo); |
| 210 | + } |
| 211 | + }; |
| 212 | + |
| 213 | + SqlStage stage = createSqlStage( |
| 214 | + metadata, |
| 215 | + new StageId(new QueryId("query"), 0), |
| 216 | + createScanPlanFragment(), |
| 217 | + ImmutableMap.of(), |
| 218 | + capturingFactory, |
| 219 | + TEST_SESSION, |
| 220 | + true, |
| 221 | + new NodeTaskMap(new FinalizerService()), |
| 222 | + executor, |
| 223 | + noopTracer(), |
| 224 | + Span.getInvalid(), |
| 225 | + new SplitSchedulerStats(), |
| 226 | + (_, _) -> OptionalInt.empty()); |
| 227 | + |
| 228 | + // Simulate a split source populating credentials after stage construction but before first task creation. |
| 229 | + ConnectorTableCredentials credentials = new ConnectorTableCredentials() {}; |
| 230 | + response.set(Optional.of(credentials)); |
| 231 | + |
| 232 | + createTestTask(stage, 0); |
| 233 | + |
| 234 | + assertThat(capturedCredentials.get()) |
| 235 | + .containsExactly(Map.entry(TABLE_SCAN_NODE_ID, credentials)); |
| 236 | + |
| 237 | + stage.finish(); |
| 238 | + } |
| 239 | + |
| 240 | + private void createTestTask(SqlStage stage, int partition) |
| 241 | + { |
| 242 | + InternalNode node = new InternalNode("node" + partition, URI.create("http://node/" + partition), NodeVersion.UNKNOWN, false); |
| 243 | + stage.createTask( |
| 244 | + node, |
| 245 | + partition, |
| 246 | + 0, |
| 247 | + Optional.empty(), |
| 248 | + OptionalInt.empty(), |
| 249 | + PipelinedOutputBuffers.createInitial(ARBITRARY), |
| 250 | + ImmutableListMultimap.of(TABLE_SCAN_NODE_ID, SPLIT.split()), |
| 251 | + ImmutableSet.of(), |
| 252 | + Optional.empty(), |
| 253 | + false); |
| 254 | + } |
| 255 | + |
113 | 256 | private void testFinalStageInfoInternal() |
114 | 257 | throws Exception |
115 | 258 | { |
@@ -255,4 +398,9 @@ private static PlanFragment createExchangePlanFragment() |
255 | 398 | ImmutableMap.of(), |
256 | 399 | Optional.empty()); |
257 | 400 | } |
| 401 | + |
| 402 | + private static PlanFragment createScanPlanFragment() |
| 403 | + { |
| 404 | + return PLAN_FRAGMENT.withOutputPartitioning(Optional.empty(), OptionalInt.empty()); |
| 405 | + } |
258 | 406 | } |
0 commit comments