From 613d12c30cfc4bd61b1071f44199017479df7bd0 Mon Sep 17 00:00:00 2001 From: Florent Delannoy Date: Wed, 24 Sep 2025 14:49:58 +0200 Subject: [PATCH] Introduce ASOF and ASOF LEFT joins with single-inequality anchor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Adds ASOF/ASOF LEFT joins for nearest‑neighbor matching. - ON clause requires exactly one inequality comparing a right‑side expression to a left‑side expression; additional equi and same‑side predicates are allowed. - Implements analyzer/planner support and predicate pushdown semantics. - Updates documentation and adds focused analyzer, planner, and query tests. --- .../antlr4/io/trino/grammar/sql/SqlBase.g4 | 3 + .../io/trino/grammar/sql/TestSqlKeywords.java | 1 + .../java/io/trino/cost/JoinStatsRule.java | 4 +- .../io/trino/operator/JoinOperatorType.java | 4 +- .../java/io/trino/operator/PagesIndex.java | 7 +- .../operator/index/IndexSnapshotBuilder.java | 2 +- .../operator/join/HashBuilderOperator.java | 9 +- .../trino/operator/join/JoinHashSupplier.java | 4 +- .../operator/join/SortedPositionLinks.java | 14 +- .../join/unspilled/HashBuilderOperator.java | 12 +- .../trino/sql/analyzer/StatementAnalyzer.java | 98 ++++ .../java/io/trino/sql/gen/JoinCompiler.java | 5 +- .../planner/EffectivePredicateExtractor.java | 5 +- .../sql/planner/LocalExecutionPlanner.java | 33 +- .../io/trino/sql/planner/RelationPlanner.java | 2 + .../sql/planner/SortExpressionExtractor.java | 2 +- .../rule/AdaptiveReorderPartitionedJoin.java | 11 +- .../rule/DetermineJoinDistributionType.java | 24 +- ...alityFilterExpressionBelowJoinRuleSet.java | 3 +- .../iterative/rule/PushJoinIntoTableScan.java | 6 + .../iterative/rule/RemoveRedundantJoin.java | 4 +- .../planner/iterative/rule/ReorderJoins.java | 11 +- .../ReplaceJoinOverConstantWithProject.java | 10 +- .../rule/ReplaceRedundantJoinWithProject.java | 4 +- .../rule/ReplaceRedundantJoinWithSource.java | 7 +- .../optimizations/IndexJoinOptimizer.java | 5 + .../optimizations/PredicatePushDown.java | 261 +++++++-- .../optimizations/PropertyDerivations.java | 7 +- .../StreamPropertyDerivations.java | 5 +- .../sql/planner/plan/CorrelatedJoinNode.java | 2 + .../io/trino/sql/planner/plan/JoinNode.java | 11 +- .../io/trino/sql/planner/plan/JoinType.java | 4 +- .../io/trino/sql/planner/plan/UnnestNode.java | 1 + .../io/trino/operator/TestPagesIndex.java | 1 + .../BenchmarkHashBuildAndJoinOperators.java | 1 + .../io/trino/operator/join/JoinTestUtils.java | 1 + .../operator/join/TestHashJoinOperator.java | 3 + .../operator/join/TestPositionLinks.java | 3 +- .../BenchmarkHashBuildAndJoinOperators.java | 1 + .../join/unspilled/JoinTestUtils.java | 1 + .../unspilled/TestHashBuilderOperator.java | 1 + .../io/trino/sql/analyzer/TestAnalyzer.java | 91 +++ .../sql/planner/TestAdaptivePlanner.java | 4 +- .../trino/sql/planner/TestLogicalPlanner.java | 42 ++ .../sql/planner/TestPredicatePushdown.java | 530 ++++++++++++++++++ .../TestAdaptiveReorderPartitionedJoin.java | 52 ++ .../TestDetermineJoinDistributionType.java | 90 +++ ...alityFilterExpressionBelowJoinRuleSet.java | 89 +++ .../rule/TestPushJoinIntoTableScan.java | 48 ++ .../rule/TestRemoveRedundantJoin.java | 37 ++ ...estReplaceJoinOverConstantWithProject.java | 58 ++ .../TestReplaceRedundantJoinWithProject.java | 34 ++ .../TestReplaceRedundantJoinWithSource.java | 33 ++ .../java/io/trino/sql/query/TestJoin.java | 269 +++++++++ .../java/io/trino/sql/parser/AstBuilder.java | 11 +- .../src/main/java/io/trino/sql/tree/Join.java | 8 +- .../io/trino/sql/parser/TestSqlParser.java | 14 + .../parser/TestSqlParserErrorHandling.java | 4 +- docs/src/main/sphinx/language/reserved.md | 1 + docs/src/main/sphinx/sql/select.md | 56 ++ 60 files changed, 1958 insertions(+), 106 deletions(-) diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 02b83dbdb7fe..fb60ed14f316 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -376,6 +376,8 @@ joinType | LEFT OUTER? | RIGHT OUTER? | FULL OUTER? + | ASOF LEFT OUTER? + | ASOF ; joinCriteria @@ -1070,6 +1072,7 @@ AND: 'AND'; ANY: 'ANY'; ARRAY: 'ARRAY'; AS: 'AS'; +ASOF: 'ASOF'; ASC: 'ASC'; AT: 'AT'; AUTHORIZATION: 'AUTHORIZATION'; diff --git a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java index cb44176dc6ef..229844743ec6 100644 --- a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java +++ b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java @@ -37,6 +37,7 @@ public void test() "ANY", "ARRAY", "AS", + "ASOF", "ASC", "AT", "AUTHORIZATION", diff --git a/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java index be4a580bcb2a..93d8ebd0ca91 100644 --- a/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java @@ -84,8 +84,8 @@ protected Optional doCalculate(JoinNode node, Context con PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats); return switch (node.getType()) { - case INNER -> Optional.of(computeInnerJoinStats(node, crossJoinStats, context.session())); - case LEFT -> Optional.of(computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, context.session())); + case INNER, ASOF -> Optional.of(computeInnerJoinStats(node, crossJoinStats, context.session())); + case LEFT, ASOF_LEFT -> Optional.of(computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, context.session())); case RIGHT -> Optional.of(computeRightJoinStats(node, leftStats, rightStats, crossJoinStats, context.session())); case FULL -> Optional.of(computeFullJoinStats(node, leftStats, rightStats, crossJoinStats, context.session())); }; diff --git a/core/trino-main/src/main/java/io/trino/operator/JoinOperatorType.java b/core/trino-main/src/main/java/io/trino/operator/JoinOperatorType.java index 3c2797bcf670..2e62342e3e32 100644 --- a/core/trino-main/src/main/java/io/trino/operator/JoinOperatorType.java +++ b/core/trino-main/src/main/java/io/trino/operator/JoinOperatorType.java @@ -31,8 +31,8 @@ public class JoinOperatorType public static JoinOperatorType ofJoinNodeType(JoinType joinNodeType, boolean outputSingleMatch, boolean waitForBuild) { return switch (joinNodeType) { - case INNER -> innerJoin(outputSingleMatch, waitForBuild); - case LEFT -> probeOuterJoin(outputSingleMatch); + case INNER, ASOF -> innerJoin(outputSingleMatch, waitForBuild); + case LEFT, ASOF_LEFT -> probeOuterJoin(outputSingleMatch); case RIGHT -> lookupOuterJoin(waitForBuild); case FULL -> fullOuterJoin(); }; diff --git a/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java b/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java index 6a449383b64d..309e758f0c01 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java +++ b/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java @@ -459,7 +459,7 @@ public PagesIndexOrdering createPagesIndexComparator(List sortChannels, public Supplier createLookupSourceSupplier(Session session, List joinChannels) { - return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), ImmutableList.of()); + return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), false, ImmutableList.of()); } public PagesHashStrategy createPagesHashStrategy(List joinChannels) @@ -498,9 +498,10 @@ public LookupSourceSupplier createLookupSourceSupplier( List joinChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories) { - return createLookupSourceSupplier(session, joinChannels, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.empty(), defaultHashArraySizeSupplier()); + return createLookupSourceSupplier(session, joinChannels, filterFunctionFactory, sortChannel, sortedPositionLinksDescendingOrder, searchFunctionFactories, Optional.empty(), defaultHashArraySizeSupplier()); } public PagesSpatialIndexSupplier createPagesSpatialIndex( @@ -524,6 +525,7 @@ public LookupSourceSupplier createLookupSourceSupplier( List joinChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, Optional> outputChannels, HashArraySizeSupplier hashArraySizeSupplier) @@ -536,6 +538,7 @@ public LookupSourceSupplier createLookupSourceSupplier( channels, filterFunctionFactory, sortChannel, + sortedPositionLinksDescendingOrder, searchFunctionFactories, hashArraySizeSupplier); } diff --git a/core/trino-main/src/main/java/io/trino/operator/index/IndexSnapshotBuilder.java b/core/trino-main/src/main/java/io/trino/operator/index/IndexSnapshotBuilder.java index 8a2c166d861b..c416f924649a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/index/IndexSnapshotBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/index/IndexSnapshotBuilder.java @@ -126,7 +126,7 @@ public IndexSnapshot createIndexSnapshot(UnloadedIndexKeyRecordSet indexKeysReco } pages.clear(); - LookupSource lookupSource = outputPagesIndex.createLookupSourceSupplier(session, keyOutputChannels, Optional.empty(), Optional.empty(), ImmutableList.of()).get(); + LookupSource lookupSource = outputPagesIndex.createLookupSourceSupplier(session, keyOutputChannels, Optional.empty(), Optional.empty(), false, ImmutableList.of()).get(); // Build a page containing the keys that produced no output rows, so in future requests can skip these keys verify(missingKeysPageBuilder.isEmpty()); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java index f5f2ff0c7280..f5a8ba9755b3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java @@ -77,6 +77,7 @@ public static class HashBuilderOperatorFactory private final List hashChannels; private final Optional filterFunctionFactory; private final Optional sortChannel; + private final boolean sortedPositionLinksDescendingOrder; private final List searchFunctionFactories; private final PagesIndex.Factory pagesIndexFactory; @@ -97,6 +98,7 @@ public HashBuilderOperatorFactory( List hashChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, int expectedPositions, PagesIndex.Factory pagesIndexFactory, @@ -115,6 +117,7 @@ public HashBuilderOperatorFactory( this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); this.sortChannel = sortChannel; + this.sortedPositionLinksDescendingOrder = sortedPositionLinksDescendingOrder; this.searchFunctionFactories = ImmutableList.copyOf(searchFunctionFactories); this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); this.spillEnabled = spillEnabled; @@ -141,6 +144,7 @@ public HashBuilderOperator createOperator(DriverContext driverContext) hashChannels, filterFunctionFactory, sortChannel, + sortedPositionLinksDescendingOrder, searchFunctionFactories, expectedPositions, pagesIndexFactory, @@ -215,6 +219,7 @@ public enum State private final List hashChannels; private final Optional filterFunctionFactory; private final Optional sortChannel; + private final boolean sortedPositionLinksDescendingOrder; private final List searchFunctionFactories; private final PagesIndex index; @@ -247,6 +252,7 @@ public HashBuilderOperator( List hashChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, int expectedPositions, PagesIndex.Factory pagesIndexFactory, @@ -261,6 +267,7 @@ public HashBuilderOperator( this.partitionIndex = partitionIndex; this.filterFunctionFactory = filterFunctionFactory; this.sortChannel = sortChannel; + this.sortedPositionLinksDescendingOrder = sortedPositionLinksDescendingOrder; this.searchFunctionFactories = searchFunctionFactories; this.localUserMemoryContext = new CoarseGrainLocalMemoryContext(operatorContext.localUserMemoryContext(), memorySyncGranularity); this.localRevocableMemoryContext = new CoarseGrainLocalMemoryContext(operatorContext.localRevocableMemoryContext(), memorySyncGranularity); @@ -650,7 +657,7 @@ private void disposeUnspilledLookupSourceIfRequested() private LookupSourceSupplier buildLookupSource() { - LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.of(outputChannels), hashArraySizeSupplier); + LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, filterFunctionFactory, sortChannel, sortedPositionLinksDescendingOrder, searchFunctionFactories, Optional.of(outputChannels), hashArraySizeSupplier); checkState(lookupSourceSupplier == null, "lookupSourceSupplier is already set"); this.lookupSourceSupplier = partition; return partition; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java b/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java index 475cf7af43d7..f69e166124c5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java @@ -62,6 +62,7 @@ public JoinHashSupplier( List> channels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, HashArraySizeSupplier hashArraySizeSupplier, OptionalInt singleBigintJoinChannel) @@ -79,7 +80,8 @@ public JoinHashSupplier( positionLinksFactoryBuilder = SortedPositionLinks.builder( addresses.size(), pagesHashStrategy, - addresses); + addresses, + sortedPositionLinksDescendingOrder); } else { positionLinksFactoryBuilder = ArrayPositionLinks.builder(addresses.size()); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java b/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java index a16065c18666..4b5725a13a52 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java @@ -54,13 +54,15 @@ public static class FactoryBuilder private final PositionComparator comparator; private final PagesHashStrategy pagesHashStrategy; private final LongArrayList addresses; + private final int orderingSign; - public FactoryBuilder(int size, PagesHashStrategy pagesHashStrategy, LongArrayList addresses) + public FactoryBuilder(int size, PagesHashStrategy pagesHashStrategy, LongArrayList addresses, boolean descendingOrder) { this.size = size; this.comparator = new PositionComparator(pagesHashStrategy, addresses); this.pagesHashStrategy = pagesHashStrategy; this.addresses = addresses; + this.orderingSign = descendingOrder ? -1 : 1; positionLinks = new Int2ObjectOpenHashMap<>(); } @@ -80,8 +82,8 @@ public int link(int from, int to) return from; } - // make sure that from value is the smaller one - if (comparator.compare(from, to) > 0) { + // make sure that from value is the smaller (or larger depending on ordering) one + if (comparator.compare(from, to) * orderingSign > 0) { // _from_ is larger so, just add to current chain _to_ positionLinks.computeIfAbsent(to, key -> new IntArrayList()).add(from); return to; @@ -121,7 +123,7 @@ public Factory build() if (positions.length > 0) { // Use the positionsList array for the merge sort temporary work buffer to avoid an extra redundant // copy. This works because we know that initially it has the same values as the array being sorted - IntArrays.mergeSort(positions, 0, positions.length, comparator, positionsList.elements()); + IntArrays.mergeSort(positions, 0, positions.length, (left, right) -> comparator.compare(left, right) * orderingSign, positionsList.elements()); // add link from starting position to position links chain arrayPositionLinksFactoryBuilder.link(key, positions[0]); // add links for the sorted internal elements @@ -190,9 +192,9 @@ private static long sizeOfPositionLinks(int[][] sortedPositionLinks) return retainedSize; } - public static FactoryBuilder builder(int size, PagesHashStrategy pagesHashStrategy, LongArrayList addresses) + public static FactoryBuilder builder(int size, PagesHashStrategy pagesHashStrategy, LongArrayList addresses, boolean descendingOrder) { - return new FactoryBuilder(size, pagesHashStrategy, addresses); + return new FactoryBuilder(size, pagesHashStrategy, addresses, descendingOrder); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/HashBuilderOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/HashBuilderOperator.java index 02d1423a5c69..7924abedcac1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/HashBuilderOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/HashBuilderOperator.java @@ -58,6 +58,7 @@ public static class HashBuilderOperatorFactory private final List hashChannels; private final Optional filterFunctionFactory; private final Optional sortChannel; + private final boolean sortedPositionLinksDescendingOrder; private final List searchFunctionFactories; private final PagesIndex.Factory pagesIndexFactory; @@ -76,6 +77,7 @@ public HashBuilderOperatorFactory( List hashChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, int expectedPositions, PagesIndex.Factory pagesIndexFactory, @@ -92,6 +94,7 @@ public HashBuilderOperatorFactory( this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); this.sortChannel = sortChannel; + this.sortedPositionLinksDescendingOrder = sortedPositionLinksDescendingOrder; this.searchFunctionFactories = ImmutableList.copyOf(searchFunctionFactories); this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); this.hashArraySizeSupplier = requireNonNull(hashArraySizeSupplier, "hashArraySizeSupplier is null"); @@ -116,6 +119,7 @@ public HashBuilderOperator createOperator(DriverContext driverContext) hashChannels, filterFunctionFactory, sortChannel, + sortedPositionLinksDescendingOrder, searchFunctionFactories, expectedPositions, pagesIndexFactory, @@ -164,6 +168,7 @@ public enum State private final List hashChannels; private final Optional filterFunctionFactory; private final Optional sortChannel; + private final boolean sortedPositionLinksDescendingOrder; private final List searchFunctionFactories; private final HashArraySizeSupplier hashArraySizeSupplier; @@ -182,12 +187,13 @@ public HashBuilderOperator( List hashChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, int expectedPositions, PagesIndex.Factory pagesIndexFactory, HashArraySizeSupplier hashArraySizeSupplier) { - this(operatorContext, lookupSourceFactory, partitionIndex, outputChannels, hashChannels, filterFunctionFactory, sortChannel, searchFunctionFactories, expectedPositions, pagesIndexFactory, hashArraySizeSupplier, DEFAULT_GRANULARITY); + this(operatorContext, lookupSourceFactory, partitionIndex, outputChannels, hashChannels, filterFunctionFactory, sortChannel, sortedPositionLinksDescendingOrder, searchFunctionFactories, expectedPositions, pagesIndexFactory, hashArraySizeSupplier, DEFAULT_GRANULARITY); } @VisibleForTesting @@ -199,6 +205,7 @@ public HashBuilderOperator( List hashChannels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, int expectedPositions, PagesIndex.Factory pagesIndexFactory, @@ -211,6 +218,7 @@ public HashBuilderOperator( this.partitionIndex = partitionIndex; this.filterFunctionFactory = filterFunctionFactory; this.sortChannel = sortChannel; + this.sortedPositionLinksDescendingOrder = sortedPositionLinksDescendingOrder; this.searchFunctionFactories = searchFunctionFactories; this.localUserMemoryContext = new CoarseGrainLocalMemoryContext(operatorContext.localUserMemoryContext(), memorySyncThreshold); @@ -345,7 +353,7 @@ private void disposeLookupSourceIfRequested() private LookupSourceSupplier buildLookupSource() { checkState(index != null, "index is null"); - LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.of(outputChannels), hashArraySizeSupplier); + LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, filterFunctionFactory, sortChannel, sortedPositionLinksDescendingOrder, searchFunctionFactories, Optional.of(outputChannels), hashArraySizeSupplier); checkState(lookupSourceSupplier == null, "lookupSourceSupplier is already set"); this.lookupSourceSupplier = partition; return partition; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 99b9dfc20984..e9b8bd905c41 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -130,11 +130,13 @@ import io.trino.sql.tree.Analyze; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.AutoGroupBy; +import io.trino.sql.tree.BetweenPredicate; import io.trino.sql.tree.Call; import io.trino.sql.tree.CallArgument; import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.Comment; import io.trino.sql.tree.Commit; +import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Corresponding; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateMaterializedView; @@ -415,7 +417,13 @@ import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.tree.DereferenceExpression.getQualifiedName; +import static io.trino.sql.tree.Join.Type.ASOF; +import static io.trino.sql.tree.Join.Type.ASOF_LEFT; import static io.trino.sql.tree.Join.Type.FULL; import static io.trino.sql.tree.Join.Type.INNER; import static io.trino.sql.tree.Join.Type.LEFT; @@ -423,6 +431,7 @@ import static io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch.ONE; import static io.trino.sql.tree.SaveMode.IGNORE; import static io.trino.sql.tree.SaveMode.REPLACE; +import static io.trino.sql.util.AstUtils.extractConjuncts; import static io.trino.sql.util.AstUtils.preOrder; import static io.trino.type.UnknownType.UNKNOWN; import static io.trino.util.MoreLists.mappedCopy; @@ -3417,6 +3426,13 @@ protected Scope visitJoin(Join node, Optional scope) Scope left = process(node.getLeft(), scope); Scope right = process(node.getRight(), isLateralRelation(node.getRight()) ? Optional.of(left) : scope); + // ASOF joins: disallow UNNEST on either side (not supported at analyzer stage) + if (node.getType() == ASOF || node.getType() == ASOF_LEFT) { + if (isUnnestRelation(node.getLeft()) || isUnnestRelation(node.getRight())) { + throw semanticException(NOT_SUPPORTED, node, "ASOF JOIN involving UNNEST is not supported"); + } + } + if (isLateralRelation(node.getRight())) { if (node.getType() == RIGHT || node.getType() == FULL) { Stream leftScopeReferences = getReferencesToScope(node.getRight(), analysis, left); @@ -3457,6 +3473,9 @@ else if (node.getType() == FULL) { } if (criteria instanceof JoinUsing joinUsing) { + if (node.getType() == ASOF || node.getType() == ASOF_LEFT) { + throw semanticException(NOT_SUPPORTED, node, "ASOF JOIN with USING clause is not supported"); + } return analyzeJoinUsing(node, joinUsing.getColumns(), scope, left, right); } @@ -3483,6 +3502,15 @@ else if (node.getType() == FULL) { analysis.recordSubqueries(node, expressionAnalysis); analysis.setJoinCriteria(node, expression); + + // For ASOF joins, enforce exactly one inequality predicate (<, <=, >, >=) that references both sides + if (node.getType() == ASOF || node.getType() == ASOF_LEFT) { + if (countAsofInequalities(expression, left, right) != 1) { + throw semanticException(INVALID_ARGUMENTS, expression, "ASOF JOIN requires exactly one inequality predicate in ON clause"); + } + // Validate that no single side of an inequality mixes references from both left and right scopes. + validateAsofInequalityScopes(expression, left, right); + } } else { throw new UnsupportedOperationException("Unsupported join criteria: " + criteria.getClass().getName()); @@ -3491,6 +3519,76 @@ else if (node.getType() == FULL) { return output; } + private int countAsofInequalities(Expression expression, Scope leftScope, Scope rightScope) + { + return (int) extractConjuncts(expression).stream() + .flatMap(this::betweenToInequalities) + .filter(conjunct -> isAsofInequalityCandidate(conjunct, leftScope, rightScope)) + .count(); + } + + private void validateAsofInequalityScopes(Expression expression, Scope leftScope, Scope rightScope) + { + extractConjuncts(expression).stream() + .flatMap(this::betweenToInequalities) + .filter(conjunct -> isAsofInequalityCandidate(conjunct, leftScope, rightScope)) + .map(ComparisonExpression.class::cast) + .forEach(comparison -> validateNoMixedScope(comparison.getLeft(), comparison.getRight(), leftScope, rightScope)); + } + + private Stream betweenToInequalities(Expression expression) + { + if (expression instanceof BetweenPredicate between) { + return Stream.of( + new ComparisonExpression(expression.getLocation().orElseThrow(), GREATER_THAN_OR_EQUAL, between.getValue(), between.getMin()), + new ComparisonExpression(expression.getLocation().orElseThrow(), LESS_THAN_OR_EQUAL, between.getValue(), between.getMax())); + } + return Stream.of(expression); + } + + private void validateNoMixedScope(Expression left, Expression right, Scope leftScope, Scope rightScope) + { + // Determine scope usage using qualified name dependencies and relation type resolution + Set leftDependencies = NamesExtractor.extractNames(left, analysis.getColumnReferences()); + boolean leftSideReferencesLeft = hasReferences(leftDependencies, leftScope); + boolean leftSideReferencesRight = hasReferences(leftDependencies, rightScope); + + Set rightDependencies = NamesExtractor.extractNames(right, analysis.getColumnReferences()); + boolean rightSideReferencesLeft = hasReferences(rightDependencies, leftScope); + boolean rightSideReferencesRight = hasReferences(rightDependencies, rightScope); + + if (leftSideReferencesLeft && leftSideReferencesRight) { + throw semanticException(INVALID_ARGUMENTS, left, "ASOF inequality side mixes left and right references"); + } + if (rightSideReferencesLeft && rightSideReferencesRight) { + throw semanticException(INVALID_ARGUMENTS, right, "ASOF inequality side mixes left and right references"); + } + } + + private boolean isAsofInequalityCandidate(Expression expression, Scope leftScope, Scope rightScope) + { + if (!(expression instanceof ComparisonExpression comparison)) { + return false; + } + + if (comparison.getOperator() != LESS_THAN && + comparison.getOperator() != LESS_THAN_OR_EQUAL && + comparison.getOperator() != GREATER_THAN && + comparison.getOperator() != GREATER_THAN_OR_EQUAL) { + return false; + } + + Set dependencies = NamesExtractor.extractNames(expression, analysis.getColumnReferences()); + boolean hasLeftSideReferences = hasReferences(dependencies, leftScope); + boolean hasRightSideReferences = hasReferences(dependencies, rightScope); + return hasLeftSideReferences && hasRightSideReferences; + } + + private boolean hasReferences(Set dependencies, Scope scope) + { + return dependencies.stream().anyMatch(scope.getRelationType()::canResolve); + } + @Override protected Scope visitUpdate(Update update, Optional scope) { diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java index 53508d1dcc33..5098c81ef3cd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java @@ -1004,7 +1004,7 @@ public LookupSourceSupplierFactory(Class joinHas { this.pagesHashStrategyFactory = pagesHashStrategyFactory; try { - constructor = joinHashSupplierClass.getConstructor(Session.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class, Optional.class, List.class, HashArraySizeSupplier.class, OptionalInt.class); + constructor = joinHashSupplierClass.getConstructor(Session.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class, Optional.class, boolean.class, List.class, HashArraySizeSupplier.class, OptionalInt.class); } catch (NoSuchMethodException e) { throw new RuntimeException(e); @@ -1018,12 +1018,13 @@ public LookupSourceSupplier createLookupSourceSupplier( List> channels, Optional filterFunctionFactory, Optional sortChannel, + boolean sortedPositionLinksDescendingOrder, List searchFunctionFactories, HashArraySizeSupplier hashArraySizeSupplier) { PagesHashStrategy pagesHashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channels); try { - return constructor.newInstance(session, pagesHashStrategy, addresses, channels, filterFunctionFactory, sortChannel, searchFunctionFactories, hashArraySizeSupplier, singleBigintJoinChannel); + return constructor.newInstance(session, pagesHashStrategy, addresses, channels, filterFunctionFactory, sortChannel, sortedPositionLinksDescendingOrder, searchFunctionFactories, hashArraySizeSupplier, singleBigintJoinChannel); } catch (ReflectiveOperationException e) { throw new RuntimeException(e); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java index 3df5380c757e..c607298bcbc0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java @@ -304,6 +304,7 @@ public Expression visitUnnest(UnnestNode node, Void context) return switch (node.getJoinType()) { case INNER, LEFT -> pullExpressionThroughSymbols(node.getSource().accept(this, context), node.getOutputSymbols()); case RIGHT, FULL -> TRUE; + case ASOF, ASOF_LEFT -> throw new IllegalStateException("ASOF joins are not supported by UNNEST"); }; } @@ -318,13 +319,13 @@ public Expression visitJoin(JoinNode node, Void context) .collect(toImmutableList()); return switch (node.getType()) { - case INNER -> pullExpressionThroughSymbols(combineConjuncts(ImmutableList.builder() + case INNER, ASOF -> pullExpressionThroughSymbols(combineConjuncts(ImmutableList.builder() .add(leftPredicate) .add(rightPredicate) .add(combineConjuncts(joinConjuncts)) .add(node.getFilter().orElse(TRUE)) .build()), node.getOutputSymbols()); - case LEFT -> combineConjuncts(ImmutableList.builder() + case LEFT, ASOF_LEFT -> combineConjuncts(ImmutableList.builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 24c54f6376e5..c48ddb4c835b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -363,12 +363,15 @@ import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.trino.sql.planner.SortExpressionExtractor.extractSortExpression; +import static io.trino.sql.planner.SortExpressionExtractor.hasBuildSymbolReference; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; @@ -379,6 +382,8 @@ import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.trino.sql.planner.plan.FrameBoundType.CURRENT_ROW; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -2550,7 +2555,7 @@ public PhysicalOperation visitJoin(JoinNode node, LocalExecutionPlanContext cont List rightSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getRight); return switch (node.getType()) { - case INNER, LEFT, RIGHT, FULL -> + case INNER, LEFT, RIGHT, FULL, ASOF, ASOF_LEFT -> createLookupJoin(node, node.getLeft(), leftSymbols, node.getRight(), rightSymbols, localDynamicFilters, context); }; } @@ -2879,6 +2884,7 @@ private PhysicalOperation createLookupJoin( // Plan build boolean buildOuter = node.getType() == RIGHT || node.getType() == FULL; + boolean asofJoin = node.getType() == ASOF || node.getType() == ASOF_LEFT; boolean spillEnabled = isSpillEnabled(session) && node.isSpillable().orElseThrow(() -> new IllegalArgumentException("spillable not yet set")) && !buildOuter; @@ -2903,6 +2909,10 @@ private PhysicalOperation createLookupJoin( .collect(toImmutableSet()) .containsAll(node.getRightOutputSymbols()); + if (asofJoin) { + outputSingleMatch = true; + } + LocalExecutionPlanContext buildContext = context.createSubContext(); PhysicalOperation buildSource = buildNode.accept(this, buildContext); @@ -2917,8 +2927,9 @@ private PhysicalOperation createLookupJoin( probeSource.getLayout(), buildLayout)); + Set rightChildOutputSymbols = ImmutableSet.copyOf(node.getRight().getOutputSymbols()); Optional sortExpressionContext = node.getFilter() - .flatMap(filter -> extractSortExpression(ImmutableSet.copyOf(node.getRight().getOutputSymbols()), filter)); + .flatMap(filter -> extractSortExpression(rightChildOutputSymbols, filter)); Optional sortChannel = sortExpressionContext .map(SortExpressionContext::getSortExpression) @@ -2935,6 +2946,22 @@ private PhysicalOperation createLookupJoin( .collect(toImmutableList())) .orElse(ImmutableList.of()); + boolean sortedPositionLinksDescendingOrder = false; + if (asofJoin) { + int searchExpressions = sortExpressionContext.map(sort -> sort.getSearchExpressions().size()).orElse(0); + checkState(searchExpressions == 1, "ASOF JOIN requires exactly one inequality predicate"); + + // SortedPositionLinks provide match candidates in ascending order. To get the closest match for ASOF JOIN + // with "> build_symbol" or ">= build_symbol" operator we need to traverse the matches in reverse order. + Comparison comparison = (Comparison) getOnlyElement(sortExpressionContext.orElseThrow().getSearchExpressions()); + boolean buildOnRight = hasBuildSymbolReference(rightChildOutputSymbols, comparison.right()); + boolean buildOnLeft = hasBuildSymbolReference(rightChildOutputSymbols, comparison.left()); + checkState(buildOnLeft != buildOnRight, "Invalid ASOF inequality expression %s", comparison); + sortedPositionLinksDescendingOrder = + buildOnRight && (comparison.operator() == GREATER_THAN || comparison.operator() == GREATER_THAN_OR_EQUAL) || + buildOnLeft && (comparison.operator() == LESS_THAN || comparison.operator() == LESS_THAN_OR_EQUAL); + } + List buildOutputTypes = buildOutputChannels.stream() .map(buildSource.getTypes()::get) .collect(toImmutableList()); @@ -2980,6 +3007,7 @@ private PhysicalOperation createLookupJoin( buildChannels, filterFunctionFactory, sortChannel, + sortedPositionLinksDescendingOrder, searchFunctionFactories, 10_000, pagesIndexFactory, @@ -3031,6 +3059,7 @@ private PhysicalOperation createLookupJoin( buildChannels, filterFunctionFactory, sortChannel, + sortedPositionLinksDescendingOrder, searchFunctionFactories, 10_000, pagesIndexFactory, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 599913025f57..b8802457d7a6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -266,6 +266,8 @@ public static JoinType mapJoinType(Join.Type joinType) case LEFT -> JoinType.LEFT; case RIGHT -> JoinType.RIGHT; case FULL -> JoinType.FULL; + case ASOF -> JoinType.ASOF; + case ASOF_LEFT -> JoinType.ASOF_LEFT; }; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java index 87f559fe9c2c..cddd74c6b679 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java @@ -143,7 +143,7 @@ private static Optional asBuildSymbolReference(Set buildLayou return Optional.empty(); } - private static boolean hasBuildSymbolReference(Set buildSymbols, Expression expression) + public static boolean hasBuildSymbolReference(Set buildSymbols, Expression expression) { return extractAll(expression).stream().anyMatch(buildSymbols::contains); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AdaptiveReorderPartitionedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AdaptiveReorderPartitionedJoin.java index a2416101862f..cd65c8680c27 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AdaptiveReorderPartitionedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AdaptiveReorderPartitionedJoin.java @@ -142,9 +142,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) return Result.empty(); } + Optional flippedJoinNode = joinNode.flipChildren(); + if (flippedJoinNode.isEmpty()) { + return Result.empty(); + } + boolean flipJoin = flipJoinBasedOnStats(joinNode, context); if (flipJoin) { - return Result.ofPlanNode(flipJoinAndFixLocalExchanges(joinNode, localExchangeNode.getId(), metadata, context)); + return Result.ofPlanNode(flipJoinAndFixLocalExchanges(flippedJoinNode.get(), localExchangeNode.getId(), metadata, context)); } return Result.empty(); } @@ -156,13 +161,11 @@ private static boolean isBuildSideLocalExchangeNode(ExchangeNode exchangeNode, S } private static JoinNode flipJoinAndFixLocalExchanges( - JoinNode joinNode, + JoinNode flippedJoinNode, PlanNodeId buildSideLocalExchangeId, Metadata metadata, Context context) { - JoinNode flippedJoinNode = joinNode.flipChildren(); - // Fix local exchange on probe side BuildToProbeLocalExchangeRewriter buildToProbeLocalExchangeRewriter = new BuildToProbeLocalExchangeRewriter( buildSideLocalExchangeId, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java index b3cad3c0eed9..0544aa9d19e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -31,6 +31,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import static io.trino.SystemSessionProperties.getJoinDistributionType; import static io.trino.SystemSessionProperties.getJoinMaxBroadcastTableSize; @@ -99,11 +100,13 @@ private PlanNode getCostBasedJoin(JoinNode joinNode, Context context) { List possibleJoinNodes = new ArrayList<>(); + Optional flippedJoin = joinNode.flipChildren(); + addJoinsWithDifferentDistributions(joinNode, possibleJoinNodes, context); - addJoinsWithDifferentDistributions(joinNode.flipChildren(), possibleJoinNodes, context); + flippedJoin.ifPresent(node -> addJoinsWithDifferentDistributions(node, possibleJoinNodes, context)); if (possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents()) || possibleJoinNodes.isEmpty()) { - return getSizeBasedJoin(joinNode, context); + return getSizeBasedJoin(joinNode, flippedJoin, context); } // Using Ordering to facilitate rule determinism @@ -111,7 +114,7 @@ private PlanNode getCostBasedJoin(JoinNode joinNode, Context context) return planNodeOrderings.min(possibleJoinNodes).getPlanNode(); } - private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context) + private JoinNode getSizeBasedJoin(JoinNode joinNode, Optional flippedJoin, Context context) { DataSize joinMaxBroadcastTableSize = getJoinMaxBroadcastTableSize(context.getSession()); @@ -122,10 +125,9 @@ private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context) } boolean isLeftSideSmall = getSourceTablesSizeInBytes(joinNode.getLeft(), context.getLookup(), context.getStatsProvider()) <= joinMaxBroadcastTableSize.toBytes(); - JoinNode flippedJoin = joinNode.flipChildren(); - if (isLeftSideSmall && !mustPartition(flippedJoin)) { + if (flippedJoin.isPresent() && isLeftSideSmall && !mustPartition(flippedJoin.get())) { // choose join left side with small source tables as replicated build side - return flippedJoin.withDistributionType(REPLICATED); + return flippedJoin.get().withDistributionType(REPLICATED); } if (isRightSideSmall) { @@ -133,9 +135,9 @@ private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context) return joinNode.withDistributionType(PARTITIONED); } - if (isLeftSideSmall) { + if (flippedJoin.isPresent() && isLeftSideSmall) { // left side is small enough, but must be partitioned - return flippedJoin.withDistributionType(PARTITIONED); + return flippedJoin.get().withDistributionType(PARTITIONED); } // Flip join sides if one side is smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times. @@ -150,8 +152,8 @@ private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context) return joinNode.withDistributionType(PARTITIONED); } - if (leftOutputSizeInBytes * SIZE_DIFFERENCE_THRESHOLD < rightOutputSizeInBytes && !mustReplicate(flippedJoin, context)) { - return flippedJoin.withDistributionType(PARTITIONED); + if (flippedJoin.isPresent() && leftOutputSizeInBytes * SIZE_DIFFERENCE_THRESHOLD < rightOutputSizeInBytes && !mustReplicate(flippedJoin.get(), context)) { + return flippedJoin.get().withDistributionType(PARTITIONED); } // neither side is small enough, choose syntactic join order @@ -186,7 +188,7 @@ private static boolean mustPartition(JoinNode joinNode) { JoinType type = joinNode.getType(); // With REPLICATED, the unmatched rows from right-side would be duplicated. - return type == RIGHT || type == FULL; + return type == RIGHT || type == FULL || type == JoinType.ASOF || type == JoinType.ASOF_LEFT; } private static boolean mustReplicate(JoinNode joinNode, Context context) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java index 7761531150eb..af47cb0e54f3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java @@ -48,6 +48,7 @@ import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.iterative.Rule.Context; import static io.trino.sql.planner.iterative.Rule.Result; +import static io.trino.sql.planner.plan.JoinType.ASOF; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.Patterns.filter; import static io.trino.sql.planner.plan.Patterns.join; @@ -111,7 +112,7 @@ private Result pushInequalityFilterExpressionBelowJoin(Context context, JoinNode Expression parentFilterPredicate = filterNode.map(FilterNode::getPredicate).orElse(TRUE); Map> parentFilterCandidates; - if (joinNode.getType() == INNER) { + if (joinNode.getType() == INNER || joinNode.getType() == ASOF) { parentFilterCandidates = extractPushDownCandidates(joinNodeContext, parentFilterPredicate); } else { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java index 37f0054ae67d..0104a6a6c3eb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java @@ -99,6 +99,11 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) return Result.empty(); } + // Do not attempt to push ASOF joins into connectors + if (joinNode.getType() == io.trino.sql.planner.plan.JoinType.ASOF || joinNode.getType() == io.trino.sql.planner.plan.JoinType.ASOF_LEFT) { + return Result.empty(); + } + TableScanNode left = captures.get(LEFT_TABLE_SCAN); TableScanNode right = captures.get(RIGHT_TABLE_SCAN); @@ -255,6 +260,7 @@ private JoinType getJoinType(JoinNode joinNode) case LEFT -> JoinType.LEFT_OUTER; case RIGHT -> JoinType.RIGHT_OUTER; case FULL -> JoinType.FULL_OUTER; + case ASOF, ASOF_LEFT -> throw new UnsupportedOperationException("ASOF join pushdown is not supported"); }; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantJoin.java index 75a2e25b8f15..382add463bb7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantJoin.java @@ -50,8 +50,8 @@ private boolean canRemoveJoin(JoinNode joinNode, Lookup lookup) PlanNode left = joinNode.getLeft(); PlanNode right = joinNode.getRight(); return switch (joinNode.getType()) { - case INNER -> isEmpty(left, lookup) || isEmpty(right, lookup); - case LEFT -> isEmpty(left, lookup); + case INNER, ASOF -> isEmpty(left, lookup) || isEmpty(right, lookup); + case LEFT, ASOF_LEFT -> isEmpty(left, lookup); case RIGHT -> isEmpty(right, lookup); case FULL -> isEmpty(left, lookup) && isEmpty(right, lookup); }; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java index d6e37e273c73..ae107a58d50c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java @@ -430,7 +430,8 @@ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) return createJoinEnumerationResult(joinNode.withDistributionType(REPLICATED)); } if (isAtMostScalar(joinNode.getLeft(), lookup)) { - return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(REPLICATED)); + // joinNode can be flipped because it is an inner join + return createJoinEnumerationResult(joinNode.flipChildren().orElseThrow().withDistributionType(REPLICATED)); } List possibleJoinNodes = getPossibleJoinNodes(joinNode, getJoinDistributionType(session)); verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty"); @@ -465,10 +466,10 @@ private List getPossibleJoinNodes(JoinNode joinNode, Dist private List getPossibleJoinNodes(JoinNode joinNode, DistributionType distributionType, Predicate isAllowed) { - List nodes = ImmutableList.of( - joinNode.withDistributionType(distributionType), - joinNode.flipChildren().withDistributionType(distributionType)); - return nodes.stream().filter(isAllowed).map(this::createJoinEnumerationResult).collect(toImmutableList()); + ImmutableList.Builder nodes = ImmutableList.builder(); + nodes.add(joinNode.withDistributionType(distributionType)); + joinNode.flipChildren().ifPresent(flipped -> nodes.add(flipped.withDistributionType(distributionType))); + return nodes.build().stream().filter(isAllowed).map(this::createJoinEnumerationResult).collect(toImmutableList()); } private JoinEnumerationResult createJoinEnumerationResult(JoinNode joinNode) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java index 702a06386e45..8d0bbfc6d429 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java @@ -35,6 +35,8 @@ import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; +import static io.trino.sql.planner.plan.JoinType.INNER; +import static io.trino.sql.planner.plan.JoinType.LEFT; import static io.trino.sql.planner.plan.Patterns.join; /** @@ -105,8 +107,8 @@ public Result apply(JoinNode node, Captures captures, Context context) boolean canInlineRightSource = canInlineJoinSource(right); return switch (node.getType()) { - case INNER -> { - if (canInlineLeftSource) { + case INNER, ASOF -> { + if (canInlineLeftSource && node.getType() == INNER) { yield Result.ofPlanNode(appendProjection(right, node.getRightOutputSymbols(), left, node.getLeftOutputSymbols(), context.getIdAllocator())); } if (canInlineRightSource) { @@ -114,8 +116,8 @@ public Result apply(JoinNode node, Captures captures, Context context) } yield Result.empty(); } - case LEFT -> { - if (canInlineLeftSource && rightCardinality.isAtLeastScalar()) { + case LEFT, ASOF_LEFT -> { + if (canInlineLeftSource && rightCardinality.isAtLeastScalar() && node.getType() == LEFT) { yield Result.ofPlanNode(appendProjection(right, node.getRightOutputSymbols(), left, node.getLeftOutputSymbols(), context.getIdAllocator())); } if (canInlineRightSource) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java index 5335060f6d4c..b15ef4098cb0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java @@ -55,8 +55,8 @@ public Result apply(JoinNode node, Captures captures, Context context) PlanNode right = node.getRight(); return switch (node.getType()) { - case INNER -> Result.empty(); - case LEFT -> !isEmpty(left, lookup) && isEmpty(right, lookup) ? + case INNER, ASOF -> Result.empty(); + case LEFT, ASOF_LEFT -> !isEmpty(left, lookup) && isEmpty(right, lookup) ? Result.ofPlanNode(appendNulls( left, node.getLeftOutputSymbols(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithSource.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithSource.java index fd8d1d803475..31c2756a1935 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithSource.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithSource.java @@ -27,6 +27,7 @@ import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; +import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.Patterns.join; /** @@ -70,10 +71,10 @@ public Result apply(JoinNode node, Captures captures, Context context) boolean rightSourceScalarWithNoOutputs = node.getRight().getOutputSymbols().isEmpty() && rightCardinality.isScalar(); return switch (node.getType()) { - case INNER -> { + case INNER, ASOF -> { PlanNode source; List sourceOutputs; - if (leftSourceScalarWithNoOutputs) { + if (leftSourceScalarWithNoOutputs && node.getType() == INNER) { source = node.getRight(); sourceOutputs = node.getRightOutputSymbols(); } @@ -89,7 +90,7 @@ else if (rightSourceScalarWithNoOutputs) { } yield Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), source, ImmutableSet.copyOf(sourceOutputs)).orElse(source)); } - case LEFT -> rightSourceScalarWithNoOutputs ? + case LEFT, ASOF_LEFT -> rightSourceScalarWithNoOutputs ? Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), node.getLeft(), ImmutableSet.copyOf(node.getLeftOutputSymbols())) .orElse(node.getLeft())) : Result.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index 4f81a4489a33..e332ce2c06ba 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -177,6 +177,11 @@ else if (leftIndexCandidate.isPresent()) { case FULL: break; + case ASOF: + case ASOF_LEFT: + // Do not attempt index join for ASOF joins + break; + default: throw new IllegalArgumentException("Unknown type: " + node.getType()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 4fd2c1e2a56e..5ed9a6bab61d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -105,6 +105,8 @@ import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.unwrapCasts; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -434,6 +436,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) switch (node.getType()) { case INNER -> { InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin( + node.getType(), inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, @@ -445,8 +448,9 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = innerJoinPushDownResult.getJoinPredicate(); } - case LEFT -> { + case LEFT, ASOF_LEFT -> { OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin( + node.getType(), inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, @@ -460,6 +464,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } case RIGHT -> { OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin( + node.getType(), inheritedPredicate, rightEffectivePredicate, leftEffectivePredicate, @@ -477,6 +482,22 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) postJoinPredicate = inheritedPredicate; newJoinPredicate = joinPredicate; } + + case ASOF -> { + AsofJoinPushDownResult asofJoinPushDownResult = processAsofJoin( + node.getType(), + inheritedPredicate, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + node.getLeft().getOutputSymbols(), + node.getRight().getOutputSymbols()); + leftPredicate = asofJoinPushDownResult.getLeftJoinPredicate(); + rightPredicate = asofJoinPushDownResult.getRightJoinPredicate(); + postJoinPredicate = asofJoinPushDownResult.getPostJoinPredicate(); + newJoinPredicate = asofJoinPushDownResult.getJoinPredicate(); + } + default -> throw new UnsupportedOperationException("Unsupported join type: " + node.getType()); } @@ -601,7 +622,7 @@ private DynamicFiltersResult createDynamicFilters( Session session, PlanNodeIdAllocator idAllocator) { - if ((node.getType() != INNER && node.getType() != RIGHT) || !isEnableDynamicFiltering(session) || !dynamicFiltering) { + if ((node.getType() != INNER && node.getType() != RIGHT && node.getType() != ASOF) || !isEnableDynamicFiltering(session) || !dynamicFiltering) { return new DynamicFiltersResult(ImmutableMap.of(), ImmutableList.of()); } @@ -751,6 +772,7 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext { InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin( + INNER, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, @@ -764,6 +786,7 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext { OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin( + LEFT, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, @@ -831,6 +854,7 @@ private Symbol symbolForExpression(Expression expression) } private OuterJoinPushDownResult processLimitedOuterJoin( + JoinType joinType, Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, @@ -845,6 +869,7 @@ private OuterJoinPushDownResult processLimitedOuterJoin( ImmutableList.Builder innerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder postJoinConjuncts = ImmutableList.builder(); ImmutableList.Builder joinConjuncts = ImmutableList.builder(); + ImmutableList.Builder inheritedInnerJoinPredicate = ImmutableList.builder(); // Strip out non-deterministic conjuncts extractConjuncts(inheritedPredicate).stream() @@ -896,6 +921,7 @@ private OuterJoinPushDownResult processLimitedOuterJoin( // A conjunct can only be pushed down into an inner side if it can be rewritten in terms of the outer side Expression innerRewritten = potentialNullSymbolInference.rewrite(outerRewritten, innerScope); if (innerRewritten != null) { + inheritedInnerJoinPredicate.add(innerRewritten); innerPushdownConjuncts.add(innerRewritten); } } @@ -912,8 +938,9 @@ private OuterJoinPushDownResult processLimitedOuterJoin( // See if we can push down join predicates to the inner side EqualityInference.nonInferrableConjuncts(joinPredicate).forEach(conjunct -> { + // Do not push down ASOF inequality candidate predicates Expression innerRewritten = potentialNullSymbolInference.rewrite(conjunct, innerScope); - if (innerRewritten != null) { + if (!isAsofInequalityCandidate(conjunct, joinType, outerSymbols, innerSymbols) && innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); } else { @@ -924,7 +951,10 @@ private OuterJoinPushDownResult processLimitedOuterJoin( return new OuterJoinPushDownResult(combineConjuncts(outerPushdownConjuncts.build()), combineConjuncts(innerPushdownConjuncts.build()), combineConjuncts(joinConjuncts.build()), - combineConjuncts(postJoinConjuncts.build())); + combineConjuncts(postJoinConjuncts.build()), + // outerPushdownConjuncts will only contain inherited predicates pushed to the left side + combineConjuncts(outerPushdownConjuncts.build()), + combineConjuncts(inheritedInnerJoinPredicate.build())); } private static class OuterJoinPushDownResult @@ -933,13 +963,17 @@ private static class OuterJoinPushDownResult private final Expression innerJoinPredicate; private final Expression joinPredicate; private final Expression postJoinPredicate; + private final Expression inheritedOuterJoinPredicate; + private final Expression inheritedInnerJoinPredicate; - private OuterJoinPushDownResult(Expression outerJoinPredicate, Expression innerJoinPredicate, Expression joinPredicate, Expression postJoinPredicate) + private OuterJoinPushDownResult(Expression outerJoinPredicate, Expression innerJoinPredicate, Expression joinPredicate, Expression postJoinPredicate, Expression inheritedOuterJoinPredicate, Expression inheritedInnerJoinPredicate) { this.outerJoinPredicate = outerJoinPredicate; this.innerJoinPredicate = innerJoinPredicate; this.joinPredicate = joinPredicate; this.postJoinPredicate = postJoinPredicate; + this.inheritedOuterJoinPredicate = inheritedOuterJoinPredicate; + this.inheritedInnerJoinPredicate = inheritedInnerJoinPredicate; } private Expression getOuterJoinPredicate() @@ -961,9 +995,119 @@ private Expression getPostJoinPredicate() { return postJoinPredicate; } + + public Expression getInheritedOuterJoinPredicate() + { + return inheritedOuterJoinPredicate; + } + + public Expression getInheritedInnerJoinPredicate() + { + return inheritedInnerJoinPredicate; + } + } + + private AsofJoinPushDownResult processAsofJoin( + JoinType joinType, + Expression inheritedPredicate, + Expression leftEffectivePredicate, + Expression rightEffectivePredicate, + Expression joinPredicate, + Collection leftSymbols, + Collection rightSymbols) + { + // ASOF inner join behaves like inner join in regard to effective predicates + InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin( + joinType, + TRUE, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + leftSymbols, + rightSymbols); + + // ASOF inner join behaves like left outer join in regard to inherited predicates + OuterJoinPushDownResult leftOuterJoinPushDownResult; + Expression postJoinPredicate = TRUE; + if (innerJoinPushDownResult.isDoNotPush() && !allowUnsafePushdown) { + // keep the inherited conjuncts that may fail above the join to preserve evaluation order of unsafe expressions + List safeInheritedConjuncts = new ArrayList<>(); + List mayFail = new ArrayList<>(); + for (Expression conjunct : extractConjuncts(inheritedPredicate)) { + if (mayFail(plannerContext, conjunct)) { + mayFail.add(conjunct); + } + else { + safeInheritedConjuncts.add(conjunct); + } + } + leftOuterJoinPushDownResult = processLimitedOuterJoin( + joinType, + combineConjuncts(safeInheritedConjuncts), + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + leftSymbols, + rightSymbols); + postJoinPredicate = combineConjuncts(mayFail); + } + else { + leftOuterJoinPushDownResult = processLimitedOuterJoin( + joinType, + inheritedPredicate, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + leftSymbols, + rightSymbols); + } + + checkState(innerJoinPushDownResult.getPostJoinPredicate().equals(TRUE)); + return new AsofJoinPushDownResult( + combineConjuncts(innerJoinPushDownResult.getLeftPredicate(), leftOuterJoinPushDownResult.getInheritedOuterJoinPredicate()), + combineConjuncts(innerJoinPushDownResult.getRightPredicate(), leftOuterJoinPushDownResult.getInheritedInnerJoinPredicate()), + innerJoinPushDownResult.getJoinPredicate(), + combineConjuncts(leftOuterJoinPushDownResult.getPostJoinPredicate(), postJoinPredicate)); + } + + private static class AsofJoinPushDownResult + { + private final Expression leftJoinPredicate; + private final Expression rightJoinPredicate; + private final Expression joinPredicate; + private final Expression postJoinPredicate; + + private AsofJoinPushDownResult(Expression leftJoinPredicate, Expression rightJoinPredicate, Expression joinPredicate, Expression postJoinPredicate) + { + this.leftJoinPredicate = leftJoinPredicate; + this.rightJoinPredicate = rightJoinPredicate; + this.joinPredicate = joinPredicate; + this.postJoinPredicate = postJoinPredicate; + } + + private Expression getLeftJoinPredicate() + { + return leftJoinPredicate; + } + + private Expression getRightJoinPredicate() + { + return rightJoinPredicate; + } + + public Expression getJoinPredicate() + { + return joinPredicate; + } + + private Expression getPostJoinPredicate() + { + return postJoinPredicate; + } } private InnerJoinPushDownResult processInnerJoin( + JoinType joinType, Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, @@ -1060,19 +1204,26 @@ else if (isInferenceCandidate(conjunct)) { .addAll(nonDeterministic); residuals.forEach(conjunct -> { - Expression leftRewrittenConjunct = allInference.rewrite(conjunct, leftScope); - if (leftRewrittenConjunct != null) { - leftPushDownConjuncts.add(leftRewrittenConjunct); - } + // Do not push down ASOF inequality candidate predicates + if (!isAsofInequalityCandidate(conjunct, joinType, leftSymbols, rightSymbols)) { + Expression leftRewrittenConjunct = allInference.rewrite(conjunct, leftScope); + if (leftRewrittenConjunct != null) { + leftPushDownConjuncts.add(leftRewrittenConjunct); + } - Expression rightRewrittenConjunct = allInference.rewrite(conjunct, rightScope); - if (rightRewrittenConjunct != null) { - rightPushDownConjuncts.add(rightRewrittenConjunct); - } + Expression rightRewrittenConjunct = allInference.rewrite(conjunct, rightScope); + if (rightRewrittenConjunct != null) { + rightPushDownConjuncts.add(rightRewrittenConjunct); + } - // Drop predicate after join only if unable to push down to either side - if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { - joinConjuncts.add(allInference.rewrite(conjunct, Sets.union(leftScope, rightScope))); + // Drop predicate after join only if unable to push down to either side + if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { + joinConjuncts.add(allInference.rewrite(conjunct, Sets.union(leftScope, rightScope))); + } + } + else { + // keep ASOF candidate inequality predicate unsimplified + joinConjuncts.add(conjunct); } }); @@ -1083,18 +1234,26 @@ else if (isInferenceCandidate(conjunct)) { joinConjuncts.add(allInference.rewrite(conjunct, Sets.union(leftScope, rightScope))); } else { - Expression leftRewrittenConjunct = allInference.rewrite(conjunct, leftScope); - if (leftRewrittenConjunct != null) { - leftPushDownConjuncts.add(leftRewrittenConjunct); + // Do not push down ASOF inequality candidate predicates + if (!isAsofInequalityCandidate(conjunct, joinType, leftSymbols, rightSymbols)) { + Expression leftRewrittenConjunct = allInference.rewrite(conjunct, leftScope); + if (leftRewrittenConjunct != null) { + leftPushDownConjuncts.add(leftRewrittenConjunct); + } + + Expression rightRewrittenConjunct = allInference.rewrite(conjunct, rightScope); + if (rightRewrittenConjunct != null) { + rightPushDownConjuncts.add(rightRewrittenConjunct); + } + + if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { + joinConjuncts.add(allInference.rewrite(conjunct, Sets.union(leftScope, rightScope))); + doNotPush = true; // we can't push any of the remaining conjuncts + } } - - Expression rightRewrittenConjunct = allInference.rewrite(conjunct, rightScope); - if (rightRewrittenConjunct != null) { - rightPushDownConjuncts.add(rightRewrittenConjunct); - } - - if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { - joinConjuncts.add(allInference.rewrite(conjunct, Sets.union(leftScope, rightScope))); + else { + // keep ASOF candidate inequality predicate unsimplified + joinConjuncts.add(conjunct); doNotPush = true; // we can't push any of the remaining conjuncts } } @@ -1104,7 +1263,30 @@ else if (isInferenceCandidate(conjunct)) { combineConjuncts(leftPushDownConjuncts.build()), combineConjuncts(rightPushDownConjuncts.build()), combineConjuncts(joinConjuncts.build()), - TRUE); + TRUE, + doNotPush); + } + + private boolean isAsofInequalityCandidate(Expression expression, JoinType joinType, Collection leftSymbols, Collection rightSymbols) + { + if (joinType != ASOF && joinType != ASOF_LEFT) { + return false; + } + + if (!(expression instanceof Comparison comparison)) { + return false; + } + if (comparison.operator() != Comparison.Operator.LESS_THAN && + comparison.operator() != Comparison.Operator.LESS_THAN_OR_EQUAL && + comparison.operator() != Comparison.Operator.GREATER_THAN && + comparison.operator() != Comparison.Operator.GREATER_THAN_OR_EQUAL) { + return false; + } + + Set symbols = extractUnique(expression); + boolean hasLeftSideReferences = leftSymbols.stream().anyMatch(symbols::contains); + boolean hasRightSideReferences = rightSymbols.stream().anyMatch(symbols::contains); + return hasLeftSideReferences && hasRightSideReferences; } private static class InnerJoinPushDownResult @@ -1113,13 +1295,15 @@ private static class InnerJoinPushDownResult private final Expression rightPredicate; private final Expression joinPredicate; private final Expression postJoinPredicate; + private final boolean doNotPush; - private InnerJoinPushDownResult(Expression leftPredicate, Expression rightPredicate, Expression joinPredicate, Expression postJoinPredicate) + private InnerJoinPushDownResult(Expression leftPredicate, Expression rightPredicate, Expression joinPredicate, Expression postJoinPredicate, boolean doNotPush) { this.leftPredicate = leftPredicate; this.rightPredicate = rightPredicate; this.joinPredicate = joinPredicate; this.postJoinPredicate = postJoinPredicate; + this.doNotPush = doNotPush; } private Expression getLeftPredicate() @@ -1141,6 +1325,11 @@ private Expression getPostJoinPredicate() { return postJoinPredicate; } + + public boolean isDoNotPush() + { + return doNotPush; + } } private Expression extractJoinPredicate(JoinNode joinNode) @@ -1155,13 +1344,13 @@ private Expression extractJoinPredicate(JoinNode joinNode) private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) { - checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType()); + checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL, ASOF, ASOF_LEFT).contains(node.getType()), "Unsupported join type: %s", node.getType()); - if (node.getType() == JoinType.INNER) { + if (node.getType() == INNER || node.getType() == ASOF) { return node; } - if (node.getType() == JoinType.FULL) { + if (node.getType() == FULL) { boolean canConvertToLeftJoin = canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate); boolean canConvertToRightJoin = canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate); if (!canConvertToLeftJoin && !canConvertToRightJoin) { @@ -1199,13 +1388,15 @@ private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheri node.getReorderJoinStatsAndCost()); } - if (node.getType() == JoinType.LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) || - node.getType() == JoinType.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate)) { + if (node.getType() == LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) || + node.getType() == ASOF_LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) || + node.getType() == RIGHT && !canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate)) { return node; } + return new JoinNode( node.getId(), - JoinType.INNER, + node.getType() == ASOF_LEFT ? ASOF : INNER, node.getLeft(), node.getRight(), node.getCriteria(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 8f64190b7512..aab95655a79f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -544,7 +544,7 @@ public ActualProperties visitJoin(JoinNode node, List inputPro boolean unordered = spillPossible(session, node.getType()); return switch (node.getType()) { - case INNER -> { + case INNER, ASOF -> { probeProperties = probeProperties.translate(column -> filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)); buildProperties = buildProperties.translate(column -> filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)); @@ -567,7 +567,7 @@ public ActualProperties visitJoin(JoinNode node, List inputPro .unordered(unordered) .build(); } - case LEFT -> ActualProperties.builderFrom(probeProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) + case LEFT, ASOF_LEFT -> ActualProperties.builderFrom(probeProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) .unordered(unordered) .build(); case RIGHT -> ActualProperties.builderFrom(buildProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) @@ -847,6 +847,7 @@ public ActualProperties visitUnnest(UnnestNode node, List inpu case RIGHT, FULL -> ActualProperties.builderFrom(translatedProperties) .local(ImmutableList.of()) .build(); + case ASOF, ASOF_LEFT -> throw new IllegalStateException("ASOF joins are not supported by UNNEST"); }; } @@ -926,7 +927,7 @@ static boolean spillPossible(Session session, JoinType joinType) return false; } return switch (joinType) { - case INNER, LEFT -> true; + case INNER, LEFT, ASOF, ASOF_LEFT -> true; // Even though join might not have "spillable" property set yet // it might still be set as spillable later on by AddLocalExchanges. case RIGHT, FULL -> false; // Currently there is no spill support for outer on the build side. diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index 6884394b9f7d..d64062cd15d9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -236,10 +236,10 @@ public StreamProperties visitJoin(JoinNode node, List inputPro boolean unordered = spillPossible(session, node); return switch (node.getType()) { - case INNER -> leftProperties + case INNER, ASOF -> leftProperties .translate(column -> PropertyDerivations.filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)) .unordered(unordered); - case LEFT -> leftProperties + case LEFT, ASOF_LEFT -> leftProperties .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputSymbols(), column)) .unordered(unordered); case RIGHT -> @@ -539,6 +539,7 @@ public StreamProperties visitUnnest(UnnestNode node, List inpu return switch (node.getJoinType()) { case INNER, LEFT -> translatedProperties; case RIGHT, FULL -> translatedProperties.unordered(true); + case ASOF, ASOF_LEFT -> throw new IllegalStateException("ASOF joins are not supported by UNNEST"); }; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java index 08be129f4b3b..ab8d183765c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java @@ -75,6 +75,8 @@ public CorrelatedJoinNode( this.input = input; this.subquery = subquery; this.correlation = ImmutableList.copyOf(correlation); + requireNonNull(type, "type is null"); + checkArgument(type != JoinType.ASOF && type != JoinType.ASOF_LEFT, "ASOF joins are not supported for correlated joins"); this.type = type; this.filter = filter; this.originSubquery = originSubquery; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java index 49b383198a03..05ab57d56e70 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java @@ -142,9 +142,13 @@ public JoinNode( } } - public JoinNode flipChildren() + public Optional flipChildren() { - return new JoinNode( + if (type == JoinType.ASOF || type == JoinType.ASOF_LEFT) { + return Optional.empty(); + } + + return Optional.of(new JoinNode( getId(), flipType(type), right, @@ -157,7 +161,7 @@ public JoinNode flipChildren() distributionType, spillable, ImmutableMap.of(), // dynamicFilters are invalid after flipping children - reorderJoinStatsAndCost); + reorderJoinStatsAndCost)); } private static JoinType flipType(JoinType type) @@ -167,6 +171,7 @@ private static JoinType flipType(JoinType type) case FULL -> FULL; case LEFT -> RIGHT; case RIGHT -> LEFT; + case ASOF, ASOF_LEFT -> throw new IllegalArgumentException("ASOF joins cannot be flipped"); }; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinType.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinType.java index 5796a9baba62..5ffc7d24848c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinType.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinType.java @@ -18,7 +18,9 @@ public enum JoinType INNER("InnerJoin"), LEFT("LeftJoin"), RIGHT("RightJoin"), - FULL("FullJoin"); + FULL("FullJoin"), + ASOF("AsofJoin"), + ASOF_LEFT("AsofLeftJoin"); private final String joinLabel; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java index cdf67cae1927..4adeb10d42da 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java @@ -56,6 +56,7 @@ public UnnestNode( this.mappings = ImmutableList.copyOf(mappings); this.ordinalitySymbol = requireNonNull(ordinalitySymbol, "ordinalitySymbol is null"); this.joinType = requireNonNull(joinType, "joinType is null"); + checkArgument(joinType != JoinType.ASOF && joinType != JoinType.ASOF_LEFT, "ASOF joins are not supported for UNNEST"); } @Override diff --git a/core/trino-main/src/test/java/io/trino/operator/TestPagesIndex.java b/core/trino-main/src/test/java/io/trino/operator/TestPagesIndex.java index ca1076fa0165..0721ec0f538d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestPagesIndex.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestPagesIndex.java @@ -150,6 +150,7 @@ public void testGetEstimatedLookupSourceSizeInBytes() ImmutableList.of(joinChannel), sortChannel.map(channel -> filterFunctionFactory), sortChannel, + false, ImmutableList.of(filterFunctionFactory), Optional.of(ImmutableList.of(0, 1)), defaultHashArraySizeSupplier()).get(); diff --git a/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java b/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java index fa079d5e6974..ddb6da9603ac 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java @@ -331,6 +331,7 @@ private static void buildHash(BuildContext buildContext, JoinBridgeManager 10"); + analyze("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t <= a.t AND b.t < 100"); + analyze("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t <= a.t AND 5 < 10"); + + // Inequality wrapped in OR is not a valid ASOF candidate + assertFails("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND (b.t <= a.t OR a.t < 0)") + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessageContaining("ASOF JOIN requires exactly one inequality predicate in ON clause"); + + // ASOF LEFT with multiple inequalities where only one is a candidate + analyze("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF LEFT JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t <= a.t AND a.t >= 0"); + analyze("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF LEFT JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t <= a.t AND b.t > -100"); + + // BETWEEN present but only one ASOF candidate (the BETWEEN is same-side only) + analyze("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t <= a.t AND a.t BETWEEN 0 AND 10"); + analyze("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t BETWEEN a.t AND 10"); + assertFails("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) " + + "ON a.k = b.k AND b.t BETWEEN a.t + b.t AND 10") + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessageContaining("ASOF inequality side mixes left and right references"); + + // USING clause with ASOF join should be rejected + assertFails("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF JOIN (VALUES (1, 1)) b(t, k) USING (k)") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:15: ASOF JOIN with USING clause is not supported"); + assertFails("SELECT 1 FROM (VALUES (1, 1)) a(t, k) ASOF LEFT JOIN (VALUES (1, 1)) b(t, k) USING (k)") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:15: ASOF JOIN with USING clause is not supported"); + } + @Test public void testNullTreatment() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java index 6aefcf1d1fb9..b06d9e47aa18 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java @@ -428,7 +428,9 @@ public Result apply(JoinNode node, Captures captures, Context context) return Result.empty(); } alreadyVisited.add(node.getId()); - return Result.ofPlanNode(node.flipChildren()); + return node.flipChildren() + .map(Result::ofPlanNode) + .orElse(Result.empty()); } } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 9c2c9b06e8d7..eceb0c933848 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -116,7 +116,9 @@ import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrExpressions.not; import static io.trino.sql.ir.Logical.Operator.AND; @@ -177,6 +179,8 @@ import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_FOLLOWING; import static io.trino.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static io.trino.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; import static io.trino.sql.planner.plan.RowsPerMatch.WINDOW; @@ -566,6 +570,44 @@ public void testInnerInequalityJoinWithEquiJoinConjuncts() "O_ORDERKEY", "orderkey")))))))); } + @Test + public void testAsofJoinPlan() + { + assertPlan(""" + SELECT o.orderkey, l.comment + FROM orders o + ASOF JOIN lineitem l + ON o.orderkey = l.orderkey AND l.partkey <= o.orderkey + """, + output(join(ASOF, builder -> builder + .equiCriteria("O_ORDERKEY", "L_ORDERKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "L_PARTKEY"), new Reference(BIGINT, "O_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), EQUAL, "L_ORDERKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), GREATER_THAN_OR_EQUAL, "L_PARTKEY"))) + .left(filter(TRUE, tableScan("orders", ImmutableMap.of( + "O_ORDERKEY", "orderkey")))) + .right(exchange(tableScan("lineitem", ImmutableMap.of( + "L_ORDERKEY", "orderkey", + "L_PARTKEY", "partkey", + "L_COMMENT", "comment"))))))); + assertPlan(""" + SELECT o.orderkey, l.comment + FROM orders o + ASOF LEFT JOIN lineitem l + ON o.orderkey = l.orderkey AND l.partkey <= o.orderkey + """, + output(join(ASOF_LEFT, builder -> builder + .equiCriteria("O_ORDERKEY", "L_ORDERKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "L_PARTKEY"), new Reference(BIGINT, "O_ORDERKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O_ORDERKEY", "orderkey"))) + .right(exchange(tableScan("lineitem", ImmutableMap.of( + "L_ORDERKEY", "orderkey", + "L_PARTKEY", "partkey", + "L_COMMENT", "comment"))))))); + } + @Test public void testLeftConvertedToInnerInequalityJoinNoEquiJoinConjuncts() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java index e2c96b14b7f6..1b3cc8eb6f5c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java @@ -19,12 +19,16 @@ import io.trino.Session; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.type.Type; +import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; import io.trino.sql.ir.Reference; +import io.trino.sql.planner.assertions.PlanMatchPattern.DynamicFilterPattern; import io.trino.sql.planner.plan.ExchangeNode; import org.junit.jupiter.api.Test; @@ -34,15 +38,26 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrExpressions.not; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.output; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.INNER; public class TestPredicatePushdown @@ -200,4 +215,519 @@ public void testNonStraddlingJoinExpression() anyTree( tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))))))); } + + @Test + public void testAsofJoinInheritedPredicatePushdown() + { + // inherited predicate referencing left side: propagate to left and via join equality to right + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o1.custkey > 10 + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O1_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o1.custkey > 10 + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O1_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + + // inherited predicate referencing both sides: remains as post-join filter + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o1.orderkey = o2.orderkey + """, + output(project(project( + filter( + new Comparison(EQUAL, new Reference(BIGINT, "O1_ORDERKEY"), new Reference(BIGINT, "O2_ORDERKEY")), + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter(TRUE, tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o1.orderkey = o2.orderkey OR o2.orderkey IS NULL + """, + output(project( + filter( + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "O1_ORDERKEY"), new Reference(BIGINT, "O2_ORDERKEY")), + new IsNull(new Reference(BIGINT, "O2_ORDERKEY")))), + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey"))) + .right(exchange(tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + + // inherited predicate referencing right side: remains as post-join filter + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o2.custkey = 10 + """, + output(project(project( + filter( + new Comparison(EQUAL, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter(TRUE, tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o2.custkey = 10 OR o2.custkey IS NULL + """, + output(project( + filter( + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + new IsNull(new Reference(BIGINT, "O2_CUSTKEY")))), + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey"))) + .right(exchange(tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + } + + @Test + public void testAsofJoinEffectivePredicateTransfers() + { + // left effective predicate transferrable to right side via join criteria + assertPlan( + """ + SELECT 1 + FROM (SELECT * FROM orders o1 WHERE o1.custkey > 10) o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O1_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + assertPlan( + """ + SELECT 1 + FROM (SELECT * FROM orders o1 WHERE o1.custkey > 10) o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O1_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + + // right effective predicate transferrable to left side via join criteria (ASOF inner only) + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN (SELECT * FROM orders o2 WHERE o2.custkey > 10) o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O1_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN (SELECT * FROM orders o2 WHERE o2.custkey > 10) o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey"))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + } + + @Test + public void testAsofJoinPredicatePushdown() + { + // left-side predicate specified in the ON clause should be pushed down to the left side (ASOF inner only) + Type commentType = createVarcharType(79); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey AND o1.comment = 'F' + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(project(filter( + new Comparison(EQUAL, new Reference(commentType, "O1_COMMENT"), new Constant(commentType, Slices.utf8Slice("F"))), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey", + "O1_COMMENT", "comment"))))) + .right(exchange( + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))); + + // BETWEEN form: entire BETWEEN stays on join (no right-side pushdown) + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey BETWEEN o1.orderkey AND 4 + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Between(new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"), new Constant(BIGINT, 4L))) + .left(filter(TRUE, + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange( + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey AND o1.comment = 'F' + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Logical(AND, ImmutableList.of( + new Comparison(EQUAL, new Reference(commentType, "O1_COMMENT"), new Constant(commentType, Slices.utf8Slice("F"))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))))) + .left(tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey", + "O1_COMMENT", "comment"))) + .right(exchange( + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))); + + // right-side predicate specified in the ON clause should be pushed down to the right side + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey AND o2.comment = 'F' + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter(TRUE, + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange(project( + filter(new Comparison(EQUAL, new Reference(commentType, "O2_COMMENT"), new Constant(commentType, Slices.utf8Slice("F"))), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey", + "O2_COMMENT", "comment")))))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey AND o2.comment = 'F' + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey"))) + .right(exchange(project( + filter(new Comparison(EQUAL, new Reference(commentType, "O2_COMMENT"), new Constant(commentType, Slices.utf8Slice("F"))), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey", + "O2_COMMENT", "comment")))))))))); + } + + @Test + public void testAsofJoinOnInequalityCandidateNotPushedRight() + { + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.custkey + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_CUSTKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter(TRUE, + tableScan("orders", ImmutableMap.of( + "O1_CUSTKEY", "custkey")))) + .right(exchange( + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.custkey + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_CUSTKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O1_CUSTKEY", "custkey"))) + .right(exchange( + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))); + } + + @Test + public void testAsofJoinRightOnlyInequalityInOnIsPushedRight() + { + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey AND o2.orderkey < 4 + """, + output(project( + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(filter(TRUE, + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange( + filter(new Comparison(LESS_THAN, new Reference(BIGINT, "O2_ORDERKEY"), new Constant(BIGINT, 4L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey AND o2.orderkey < 4 + """, + output(project( + join(ASOF_LEFT, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .left(tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey"))) + .right(exchange( + filter(new Comparison(LESS_THAN, new Reference(BIGINT, "O2_ORDERKEY"), new Constant(BIGINT, 4L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))); + } + + @Test + public void testAsofJoinInheritedUnsafePredicateStaysAboveJoin() + { + // Inherited predicate with potential failure (CAST on comment) must remain above the ASOF join, + // while safe inherited predicates (o1.custkey > 10) are pushed down to both sides via equality. + Type commentType = createVarcharType(79); + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE CAST(o1.comment AS bigint) > 0 AND o1.custkey > 10 + """, + output(project(project( + filter( + new Comparison(GREATER_THAN, new Cast(new Reference(commentType, "O1_COMMENT"), BIGINT), new Constant(BIGINT, 0L)), + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O1_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey", + "O1_COMMENT", "comment")))) + .right(exchange(filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_CUSTKEY"), new Constant(BIGINT, 10L)), + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey"))))))))))); + } + + @Test + public void testAsofLeftJoinInheritedPredicateNormalizesJoinToAsof() + { + assertPlan( + """ + SELECT 1 + FROM orders o1 + ASOF LEFT JOIN orders o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.orderkey + WHERE o2.orderkey > 0 + """, + output(project(project( + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "O2_ORDERKEY"), new Constant(BIGINT, 0L)), + join(ASOF, builder -> builder + .equiCriteria("O1_CUSTKEY", "O2_CUSTKEY") + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O2_ORDERKEY"), new Reference(BIGINT, "O1_ORDERKEY"))) + .dynamicFilter(ImmutableList.of( + new DynamicFilterPattern(new Reference(BIGINT, "O1_CUSTKEY"), EQUAL, "O2_CUSTKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O1_ORDERKEY"), GREATER_THAN_OR_EQUAL, "O2_ORDERKEY"))) + .left(filter(TRUE, + tableScan("orders", ImmutableMap.of( + "O1_ORDERKEY", "orderkey", + "O1_CUSTKEY", "custkey")))) + .right(exchange( + tableScan("orders", ImmutableMap.of( + "O2_ORDERKEY", "orderkey", + "O2_CUSTKEY", "custkey")))))))))); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAdaptiveReorderPartitionedJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAdaptiveReorderPartitionedJoin.java index 0063170d7861..16be94fa05bc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAdaptiveReorderPartitionedJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAdaptiveReorderPartitionedJoin.java @@ -43,6 +43,7 @@ import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static io.trino.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; +import static io.trino.sql.planner.plan.JoinType.ASOF; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -221,6 +222,57 @@ public void testNoChangesWhenEitherBuildOrProbeSideIsNan() .doesNotFire(); } + @Test + public void testAsofJoinIsNotReordered() + { + // Configure stats that would cause reordering for INNER join, but must not for ASOF + RuleTester ruleTester = tester(); + String buildRemoteSourceId = "buildRemoteSourceId"; + String probeRemoteSourceId = "probeRemoteSourceId"; + ruleTester.assertThat(new AdaptiveReorderPartitionedJoin(ruleTester.getMetadata())) + .setSystemProperty(RETRY_POLICY, TASK.name()) + .overrideStats("buildRemoteSourceId", PlanNodeStatsEstimate.builder() + .setOutputRowCount(20_000_000_000L) + .build()) + .overrideStats("probeRemoteSourceId", PlanNodeStatsEstimate.builder() + .setOutputRowCount(10_000_000_000L) + .build()) + .on(p -> { + Symbol buildSymbol = p.symbol("buildSymbol", BIGINT); + Symbol symbol1 = p.symbol("symbol1", BIGINT); + Symbol probeSymbol = p.symbol("probeSymbol", BIGINT); + Symbol symbol2 = p.symbol("symbol2", BIGINT); + return p.join( + ASOF, + PARTITIONED, + // probe on the left + p.remoteSource( + new PlanNodeId(probeRemoteSourceId), + ImmutableList.of(new PlanFragmentId("1")), + ImmutableList.of(probeSymbol, symbol2), + Optional.empty(), + REPARTITION, + TASK), + // build on the right + p.exchange(builder -> builder + .addInputsSet(buildSymbol, symbol1) + .addSource(p.remoteSource( + new PlanNodeId(buildRemoteSourceId), + ImmutableList.of(new PlanFragmentId("2")), + ImmutableList.of(buildSymbol, symbol1), + Optional.empty(), + REPARTITION, + TASK)) + .fixedHashDistributionPartitioningScheme( + ImmutableList.of(buildSymbol, symbol1), + ImmutableList.of(buildSymbol)) + .type(REPARTITION) + .scope(LOCAL)), + new JoinNode.EquiJoinClause(probeSymbol, buildSymbol)); + }) + .doesNotFire(); + } + private RuleAssert assertWithPartialAgg(double buildRowCount, double probeRowCount) { RuleTester ruleTester = tester(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java index dce9c6f20109..75b00c738110 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java @@ -152,6 +152,96 @@ public void testRepartitionRightOuter() testRepartitionRightOuter(JoinDistributionType.AUTOMATIC, RIGHT); } + @Test + public void testAsofJoinDoesNotFlip() + { + // Set up stats and cardinalities such that a cost-based rule would prefer flipping children. + PlanNodeStatsEstimate smallProbe = PlanNodeStatsEstimate.builder() + .setOutputRowCount(1) + .build(); + PlanNodeStatsEstimate largeBuild = PlanNodeStatsEstimate.builder() + .setOutputRowCount(10_000_000_000L) + .build(); + + assertDetermineJoinDistributionType() + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) + .overrideStats("probeValues", smallProbe) + .overrideStats("buildValues", largeBuild) + .on(p -> { + Symbol probeKey = p.symbol("probe_key", BIGINT); + Symbol probeTs = p.symbol("probe_ts", BIGINT); + Symbol buildKey = p.symbol("build_key", BIGINT); + Symbol buildTs = p.symbol("build_ts", BIGINT); + + return p.join( + JoinType.ASOF, + // small left/probe relation + p.values(new PlanNodeId("probeValues"), ImmutableList.of(probeKey, probeTs), + ImmutableList.of(ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 10L)))), + // large right/build relation (stats-controlled) + p.values(new PlanNodeId("buildValues"), ImmutableList.of(buildKey, buildTs), + ImmutableList.of(ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 9L)))), + new Comparison( + Comparison.Operator.LESS_THAN_OR_EQUAL, + new Reference(buildTs.type(), buildTs.name()), + new Reference(probeTs.type(), probeTs.name())), + new JoinNode.EquiJoinClause(probeKey, buildKey)); + }) + .matches(join(JoinType.ASOF, builder -> builder + .equiCriteria("probe_key", "build_key") + .filter(new Comparison( + Comparison.Operator.LESS_THAN_OR_EQUAL, + new Reference(BIGINT, "build_ts"), + new Reference(BIGINT, "probe_ts"))) + .left(values(ImmutableMap.of("probe_key", 0, "probe_ts", 1))) + .right(values(ImmutableMap.of("build_key", 0, "build_ts", 1))))); + } + + @Test + public void testAsofJoinAssignedPartitionedDistribution() + { + // Configure stats such that non-ASOF joins would prefer REPLICATED (small build), + // but ASOF must be PARTITIONED. + PlanNodeStatsEstimate probeStats = PlanNodeStatsEstimate.builder() + .setOutputRowCount(10_000_000_000L) + .build(); + PlanNodeStatsEstimate buildStats = PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .build(); + + assertDetermineJoinDistributionType() + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) + .overrideStats("probeValues", probeStats) + .overrideStats("buildValues", buildStats) + .on(p -> { + Symbol probeKey = p.symbol("probe_key", BIGINT); + Symbol probeTs = p.symbol("probe_ts", BIGINT); + Symbol buildKey = p.symbol("build_key", BIGINT); + Symbol buildTs = p.symbol("build_ts", BIGINT); + + return p.join( + JoinType.ASOF, + p.values(new PlanNodeId("probeValues"), ImmutableList.of(probeKey, probeTs), + ImmutableList.of(ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 10L)))), + p.values(new PlanNodeId("buildValues"), ImmutableList.of(buildKey, buildTs), + ImmutableList.of(ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 9L)))), + new Comparison( + Comparison.Operator.LESS_THAN_OR_EQUAL, + new Reference(buildTs.type(), buildTs.name()), + new Reference(probeTs.type(), probeTs.name())), + new JoinNode.EquiJoinClause(probeKey, buildKey)); + }) + .matches(join(JoinType.ASOF, builder -> builder + .equiCriteria("probe_key", "build_key") + .distributionType(PARTITIONED) + .filter(new Comparison( + Comparison.Operator.LESS_THAN_OR_EQUAL, + new Reference(BIGINT, "build_ts"), + new Reference(BIGINT, "probe_ts"))) + .left(values(ImmutableMap.of("probe_key", 0, "probe_ts", 1))) + .right(values(ImmutableMap.of("build_key", 0, "build_ts", 1))))); + } + private void testRepartitionRightOuter(JoinDistributionType sessionDistributedJoin, JoinType joinType) { assertDetermineJoinDistributionType() diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java index dac1195a41a3..bae3b1613d69 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java @@ -42,6 +42,8 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.INNER; public class TestPushInequalityFilterExpressionBelowJoinRuleSet @@ -96,6 +98,50 @@ public void testJoinFilterExpressionPushedDownToRightJoinSource() values("b"))))); } + @Test + public void testJoinFilterExpressionPushedDownToRightJoinSourceAsof() + { + tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.join( + ASOF, + p.values(a), + p.values(b), + comparison(LESS_THAN, add(b, 1), a.toSymbolReference())); + }) + .matches( + join(ASOF, builder -> builder + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "a"))) + .left(values("a")) + .right(project( + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), + values("b"))))); + } + + @Test + public void testJoinFilterExpressionPushedDownToRightJoinSourceAsofLeft() + { + tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.join( + ASOF_LEFT, + p.values(a), + p.values(b), + comparison(LESS_THAN, add(b, 1), a.toSymbolReference())); + }) + .matches( + join(ASOF_LEFT, builder -> builder + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "a"))) + .left(values("a")) + .right(project( + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), + values("b"))))); + } + @Test public void testManyJoinFilterExpressionsPushedDownToRightJoinSource() { @@ -219,6 +265,49 @@ public void testManyParentFilterExpressionsPushedDownToRightJoinSource() values("b"))))))); } + @Test + public void testParentFilterExpressionPushedDownToRightJoinSourceAsof() + { + tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.filter( + comparison(LESS_THAN, add(b, 1), a.toSymbolReference()), + p.join( + ASOF, + p.values(a), + p.values(b))); + }) + .matches( + project( + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "a")), + join(ASOF, builder -> builder + .left(values("a")) + .right( + project( + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), + values("b"))))))); + } + + @Test + public void testParentFilterExpressionNotPushedDownToRightJoinSourceAsofLeft() + { + tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.filter( + comparison(LESS_THAN, add(b, 1), a.toSymbolReference()), + p.join( + ASOF_LEFT, + p.values(a), + p.values(b))); + }) + .doesNotFire(); + } + @Test public void testOnlyParentFilterExpressionExposedInaJoin() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java index 7e0981f7d9b3..d6c6f2db67cb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java @@ -539,6 +539,53 @@ public void testPushJoinIntoTableDoesNotFireForCrossJoin() } } + @Test + public void testPushJoinIntoTableDoesNotFireForAsofJoin() + { + assertAsofJoinDoesNotFire(io.trino.sql.planner.plan.JoinType.ASOF); + } + + @Test + public void testPushJoinIntoTableDoesNotFireForAsofLeftJoin() + { + assertAsofJoinDoesNotFire(io.trino.sql.planner.plan.JoinType.ASOF_LEFT); + } + + private void assertAsofJoinDoesNotFire(io.trino.sql.planner.plan.JoinType joinType) + { + MockConnectorFactory connectorFactory = createMockConnectorFactory( + (_, _, _, _, _, _, _, _) -> { + throw new IllegalStateException("applyJoin should not be called!"); + }); + try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext())) + .withSession(MOCK_SESSION) + .on(p -> { + Symbol columnA1Symbol = p.symbol(COLUMN_A1); + Symbol columnA2Symbol = p.symbol(COLUMN_A2); + Symbol columnB1Symbol = p.symbol(COLUMN_B1); + + TableScanNode left = p.tableScan( + ruleTester.getCurrentCatalogTableHandle(SCHEMA, TABLE_A), + ImmutableList.of(columnA1Symbol, columnA2Symbol), + ImmutableMap.of( + columnA1Symbol, COLUMN_A1_HANDLE, + columnA2Symbol, COLUMN_A2_HANDLE)); + TableScanNode right = p.tableScan( + ruleTester.getCurrentCatalogTableHandle(SCHEMA, TABLE_B), + ImmutableList.of(columnB1Symbol), + ImmutableMap.of(columnB1Symbol, COLUMN_B1_HANDLE)); + + return p.join( + joinType, + left, + right, + new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, columnA1Symbol.toSymbolReference(), columnB1Symbol.toSymbolReference())); + }) + .doesNotFire(); + } + } + @Test public void testPushJoinIntoTableRequiresFullColumnHandleMappingInResult() { @@ -629,6 +676,7 @@ private JoinType toSpiJoinType(io.trino.sql.planner.plan.JoinType joinType) case LEFT -> JoinType.LEFT_OUTER; case RIGHT -> JoinType.RIGHT_OUTER; case FULL -> JoinType.FULL_OUTER; + case ASOF, ASOF_LEFT -> throw new UnsupportedOperationException("ASOF join pushdown is not supported"); }; } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantJoin.java index 0610f7a5a081..10709c70f2c4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantJoin.java @@ -17,6 +17,8 @@ import org.junit.jupiter.api.Test; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -100,4 +102,39 @@ public void testFullJoinRemoval() p.values(0, p.symbol("b")))) .matches(values("a", "b")); } + + @Test + public void testAsofJoinRemoval() + { + // Right empty + tester().assertThat(new RemoveRedundantJoin()) + .on(p -> + p.join( + ASOF, + p.values(10, p.symbol("a")), + p.values(0))) + .matches(values("a")); + + // Left empty + tester().assertThat(new RemoveRedundantJoin()) + .on(p -> + p.join( + ASOF, + p.values(0), + p.values(10, p.symbol("b")))) + .matches(values("b")); + } + + @Test + public void testAsofLeftJoinRemoval() + { + // Left empty + tester().assertThat(new RemoveRedundantJoin()) + .on(p -> + p.join( + ASOF_LEFT, + p.values(0, p.symbol("a")), + p.values(10, p.symbol("b")))) + .matches(values("a", "b")); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java index 59a568f43893..d87f4c6e3176 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java @@ -39,6 +39,8 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -201,6 +203,34 @@ public void testReplaceInnerJoinWithProject() values("c"))); } + @Test + public void testReplaceAsofJoinWithProject() + { + // For ASOF, only right constant source is inlined + tester().assertThat(new ReplaceJoinOverConstantWithProject()) + .on(p -> + p.join( + ASOF, + p.values(5, p.symbol("c")), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))))) + .matches( + project( + ImmutableMap.of( + "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), + values("c"))); + + // Left constant source should NOT be inlined for ASOF + tester().assertThat(new ReplaceJoinOverConstantWithProject()) + .on(p -> + p.join( + ASOF, + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), + p.values(5, p.symbol("c")))) + .doesNotFire(); + } + @Test public void testReplaceLeftJoinWithProject() { @@ -233,6 +263,34 @@ public void testReplaceLeftJoinWithProject() values("c"))); } + @Test + public void testReplaceAsofLeftJoinWithProject() + { + // For ASOF LEFT, only right constant source is inlined + tester().assertThat(new ReplaceJoinOverConstantWithProject()) + .on(p -> + p.join( + ASOF_LEFT, + p.values(5, p.symbol("c")), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))))) + .matches( + project( + ImmutableMap.of( + "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), + values("c"))); + + // Left constant source should NOT be inlined for ASOF LEFT + tester().assertThat(new ReplaceJoinOverConstantWithProject()) + .on(p -> + p.join( + ASOF_LEFT, + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), + p.values(5, p.symbol("c")))) + .doesNotFire(); + } + @Test public void testReplaceRightJoinWithProject() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java index 2bcd116780d1..bcb319587a19 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java @@ -24,6 +24,8 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinType.ASOF; +import static io.trino.sql.planner.plan.JoinType.ASOF_LEFT; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -43,6 +45,15 @@ public void testDoesNotFireOnInnerJoin() p.values(0, p.symbol("a")), p.values(0, p.symbol("b")))) .doesNotFire(); + + // ASOF is treated like INNER in this rule (never fires) + tester().assertThat(new ReplaceRedundantJoinWithProject()) + .on(p -> + p.join( + ASOF, + p.values(0, p.symbol("a")), + p.values(0, p.symbol("b")))) + .doesNotFire(); } @Test @@ -56,6 +67,15 @@ public void testDoesNotFireWhenOuterSourceEmpty() p.values(0, p.symbol("b")))) .doesNotFire(); + // ASOF LEFT behaves like LEFT here; outer (left) empty -> no fire + tester().assertThat(new ReplaceRedundantJoinWithProject()) + .on(p -> + p.join( + ASOF_LEFT, + p.values(0, p.symbol("a")), + p.values(0, p.symbol("b")))) + .doesNotFire(); + tester().assertThat(new ReplaceRedundantJoinWithProject()) .on(p -> p.join( @@ -92,6 +112,20 @@ public void testReplaceLeftJoin() "a", expression(new Reference(BIGINT, "a")), "b", expression(new Constant(BIGINT, null))), values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); + + // ASOF LEFT behaves like LEFT; replace with left and append nulls for right outputs + tester().assertThat(new ReplaceRedundantJoinWithProject()) + .on(p -> + p.join( + ASOF_LEFT, + p.values(10, p.symbol("a")), + p.values(0, p.symbol("b")))) + .matches( + project( + ImmutableMap.of( + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Constant(BIGINT, null))), + values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java index 512bf80dd532..705c5238360e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java @@ -27,9 +27,11 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinType.ASOF; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -268,6 +270,37 @@ public void testReplaceFullJoin() .doesNotFire(); } + @Test + public void testDoesNotReplaceAsofJoinWithLeftScalarNoOutputs() + { + // ASOF join with left scalar (no outputs) should not be replaced + tester().assertThat(new ReplaceRedundantJoinWithSource()) + .on(p -> + p.join( + ASOF, + p.values(1), + p.values(10, p.symbol("b", BIGINT)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b"), new Constant(BIGINT, 0L)))) + .doesNotFire(); + } + + @Test + public void testReplaceAsofJoinWithRightScalarNoOutputs() + { + // ASOF join with right scalar (no outputs) should be replaced with left source and preserve filter + tester().assertThat(new ReplaceRedundantJoinWithSource()) + .on(p -> + p.join( + ASOF, + p.values(10, p.symbol("a", BIGINT)), + p.values(1), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)))) + .matches( + filter( + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)), + values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); + } + @Test public void testPruneOutputs() { diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java index d91e2105735e..f9b56abcfcc7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java @@ -136,6 +136,275 @@ WHERE CAST(x AS bigint) IS NOT NULL AND y = 'hello' .isEmpty(); } + @Test + public void testAsofJoinExecution() + { + // <= direction + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (1, 1), (2, 1), (3, 2)), + b(t, k, v) AS (VALUES (1, 1, 'x'), (2, 1, 'y'), (4, 1, 'z'), (1, 2, 'q')) + SELECT a.t, a.k, b.v + FROM a ASOF JOIN b ON a.k = b.k AND b.t <= a.t + ORDER BY 1 + """)) + .matches("VALUES (1, 1, 'x'), (2, 1, 'y'), (3, 2, 'q')"); + + // >= direction + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (1, 1), (2, 1), (3, 2)), + b(t, k, v) AS (VALUES (1, 1, 'x'), (2, 1, 'y'), (4, 1, 'z'), (1, 2, 'q')) + SELECT a.t, a.k, b.v + FROM a ASOF JOIN b ON a.k = b.k AND b.t >= a.t + ORDER BY 1 + """)) + .matches("VALUES (1, 1, 'x'), (2, 1, 'y')"); + } + + @Test + public void testAsofJoinWithNullInequalityColumn() + { + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (2, 1), (4, 1)), + b(t, k, v) AS (VALUES (CAST(NULL AS INTEGER), 1, 'null'), (1, 1, 'one'), (3, 1, 'three')) + SELECT a.t, a.k, b.v + FROM a ASOF JOIN b ON a.k = b.k AND b.t <= a.t + ORDER BY 1 + """)) + .matches("VALUES (2, 1, 'one'), (4, 1, 'three')"); + } + + @Test + public void testAsofLeftJoinExecution() + { + // <= direction + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (1, 1), (5, 1), (3, 3)), + b(t, k, v) AS (VALUES (2, 1, 'm'), (4, 1, 'n')) + SELECT a.t, a.k, b.v + FROM a ASOF LEFT JOIN b ON a.k = b.k AND b.t <= a.t + ORDER BY 1 + """)) + .matches("VALUES (1, 1, NULL), (3, 3, NULL), (5, 1, 'n')"); + + // >= direction + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (1, 1), (5, 1), (3, 3)), + b(t, k, v) AS (VALUES (2, 1, 'm'), (4, 1, 'n')) + SELECT a.t, a.k, b.v + FROM a ASOF LEFT JOIN b ON a.k = b.k AND b.t >= a.t + ORDER BY 1 + """)) + .matches("VALUES (1, 1, 'm'), (3, 3, NULL), (5, 1, NULL)"); + } + + @Test + public void testAsofLeftJoinWithNullInequalityColumn() + { + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (2, 1), (5, 1)), + b(t, k, v) AS (VALUES (CAST(NULL AS INTEGER), 1, 'null'), (3, 1, 'three')) + SELECT a.t, a.k, b.v + FROM a ASOF LEFT JOIN b ON a.k = b.k AND b.t <= a.t + ORDER BY 1 + """)) + .matches("VALUES (2, 1, NULL), (5, 1, 'three')"); + } + + @Test + public void testAsofJoinStrictInequality() + { + // strict < direction + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (2, 1)), + b(t, k, v) AS (VALUES (1, 1, 'less_'), (2, 1, 'equal')) + SELECT b.v + FROM a ASOF JOIN b ON a.k = b.k AND b.t < a.t + """)) + .matches("VALUES 'less_'"); + + // strict > direction + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (2, 1)), + b(t, k, v) AS (VALUES (1, 1, 'less_'), (2, 1, 'equal')) + SELECT b.v + FROM a ASOF JOIN b ON a.k = b.k AND b.t > a.t + """)) + .returnsEmptyResult(); + } + + @Test + public void testAsofJoinWithAdditionalBuildFilter() + { + // Additional build-side predicate (b.t % 2 = 0) filters out closest candidate (odd timestamp), + // so the join selects the next valid candidate satisfying the inequality. + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (3, 1), (5, 1)), + b(t, k, v) AS (VALUES (2, 1, 'x2'), (4, 1, 'x4'), (5, 1, 'x5')) + SELECT a.t, a.k, b.v + FROM a ASOF JOIN b ON a.k = b.k AND b.t <= a.t AND (b.t % 2) = 0 + ORDER BY 1 + """)) + .matches("VALUES (3, 1, 'x2'), (5, 1, 'x4')"); + } + + @Test + public void testAsofJoinWithoutEquiCondition() + { + // ASOF join using only inequality (no equi-conjunct) + assertThat(assertions.query( + """ + WITH + a(t) AS (VALUES 0, 1, 3), + b(t, v) AS (VALUES (1, 'x1'), (2, 'x2')) + SELECT a.t, b.v + FROM a ASOF JOIN b ON b.t <= a.t + ORDER BY 1 + """)) + .matches("VALUES (1, 'x1'), (3, 'x2')"); + + // Reverse direction with only inequality (>=) + assertThat(assertions.query( + """ + WITH + a(t) AS (VALUES 0, 1, 3), + b(t, v) AS (VALUES (1, 'x1'), (2, 'x2')) + SELECT a.t, b.v + FROM a ASOF JOIN b ON b.t >= a.t + ORDER BY 1 + """)) + .matches("VALUES (0, 'x1'), (1, 'x1')"); + } + + @Test + public void testAsofJoinWithExtraInequalities() + { + // Extra right-only inequality (b.t < 3) excludes the closest candidate for a.t = 3, + // changing the chosen match from b.t = 3 to b.t = 2. + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (2, 1), (3, 1)), + b(t, k, v) AS (VALUES (1, 1, 'v1'), (2, 1, 'v2'), (2, 1, 'v2'), (3, 1, 'v3')) + SELECT a.t, a.k, b.v + FROM a ASOF JOIN b + ON a.k = b.k AND b.t <= a.t AND b.t < 3 AND a.t > 2 + ORDER BY 1 + """)) + .matches("VALUES (3, 1, 'v2')"); + } + + @Test + public void testAsofJoinOrdersLikeWithValuesUsingCustkeyBound() + { + assertThat(assertions.query( + """ + WITH + o1(orderkey, custkey) AS (VALUES (1, 1), (2, 1), (3, 2)), + o2(orderkey, custkey, v) AS (VALUES (1, 1, 'x'), (2, 1, 'y'), (4, 1, 'z'), (1, 2, 'q'), (2, 2, 'v')) + SELECT o1.custkey, o2.v + FROM o1 ASOF JOIN o2 + ON o1.custkey = o2.custkey AND o2.orderkey <= o1.custkey + ORDER BY 1, 2 + """)) + .matches("VALUES (1, 'x'), (1, 'x'), (2, 'v')"); + } + + @Test + public void testCorrelatedSubqueryInAsofJoinClause() + { + // Correlation in ASOF join clause is not allowed (treated like outer join for correlation purposes) + assertThat(assertions.query( + "SELECT * FROM (VALUES 1, 2) t(x) ASOF JOIN (VALUES 1, 3) u(x) ON t.x IN (SELECT v.x FROM (VALUES 1, 2) v(x) WHERE u.x = v.x)")) + .failure() + .hasMessageContaining("Reference to column 'u.x' from outer scope not allowed in this context"); + + assertThat(assertions.query( + "SELECT * FROM (VALUES 1, 2) t(x) ASOF JOIN (VALUES 1, 3) u(x) ON u.x IN (SELECT v.x FROM (VALUES 1, 2) v(x) WHERE t.x = v.x)")) + .failure() + .hasMessageContaining("Reference to column 't.x' from outer scope not allowed in this context"); + } + + @Test + public void testAsofJoinBuildSideComplexExpressionPushdown() + { + // Verify correctness when inequality uses complex build-side expression + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (3, 1), (2, 1)), + b(t, k) AS (VALUES (1, 1), (2, 1)) + SELECT a.t, a.k, b.t + FROM a ASOF JOIN b ON a.k = b.k AND (b.t + 1) <= a.t + ORDER BY 1 + """)) + .matches("VALUES (2, 1, 1), (3, 1, 2)"); + } + + @Test + public void testAsofJoinBuildSideComplexExpressionPushdownNoBuildOutputs() + { + // Same as above, but ensure no build-side columns are projected in the output + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (3, 1), (2, 1)), + b(t, k) AS (VALUES (1, 1), (2, 1)) + SELECT a.t, a.k + FROM a ASOF JOIN b ON a.k = b.k AND (b.t + 1) <= a.t + ORDER BY 1 + """)) + .matches("VALUES (2, 1), (3, 1)"); + } + + @Test + public void testAsofJoinBuildSideCastExpression() + { + // Build-side inequality uses CAST on the build column + assertThat(assertions.query( + """ + WITH + a(t, k) AS (VALUES (3, 1), (2, 1), (5, 1)), + b(t, k) AS (VALUES (CAST('1' AS varchar), 1), (CAST('2' AS varchar), 1), (CAST('4' AS varchar), 1)) + SELECT a.t, a.k, b.t + FROM a ASOF JOIN b ON a.k = b.k AND (CAST(b.t AS integer) + 1) <= a.t + ORDER BY 1 + """)) + .matches("VALUES (2, 1, CAST('1' AS VARCHAR)), (3, 1, CAST('2' AS VARCHAR)), (5, 1, CAST('4' AS VARCHAR))"); + } + + @Test + public void testCorrelatedSubqueryInAsofLeftJoinClause() + { + // Correlation in ASOF LEFT join clause is not allowed + assertThat(assertions.query( + "SELECT * FROM (VALUES 1, 2) t(x) ASOF LEFT JOIN (VALUES 1, 3) u(x) ON t.x IN (SELECT v.x FROM (VALUES 1, 2) v(x) WHERE u.x = v.x)")) + .failure() + .hasMessageContaining("Reference to column 'u.x' from outer scope not allowed in this context"); + + assertThat(assertions.query( + "SELECT * FROM (VALUES 1, 2) t(x) ASOF LEFT JOIN (VALUES 1, 3) u(x) ON u.x IN (SELECT v.x FROM (VALUES 1, 2) v(x) WHERE t.x = v.x)")) + .failure() + .hasMessageContaining("Reference to column 't.x' from outer scope not allowed in this context"); + } + @Test public void testInPredicateInJoinCriteria() { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 8a33fe5e4916..2a6af21a6da4 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -1963,7 +1963,16 @@ else if (context.joinCriteria().USING() != null) { } Join.Type joinType; - if (context.joinType().LEFT() != null) { + Token joinTypeStart = context.joinType().getStart(); + if (joinTypeStart.getType() == SqlBaseLexer.ASOF) { + if (context.joinType().LEFT() != null) { + joinType = Join.Type.ASOF_LEFT; + } + else { + joinType = Join.Type.ASOF; + } + } + else if (context.joinType().LEFT() != null) { joinType = Join.Type.LEFT; } else if (context.joinType().RIGHT() != null) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java index 8063052b3bfd..9a1c00e4530f 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java @@ -28,7 +28,13 @@ public class Join { public enum Type { - CROSS, INNER, LEFT, RIGHT, FULL, IMPLICIT + CROSS, INNER, LEFT, RIGHT, FULL, ASOF, ASOF_LEFT, IMPLICIT; + + @Override + public String toString() + { + return name().replace('_', ' '); + } } public Join(Type type, Relation left, Relation right, Optional criteria) diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 8f9f0711e720..5efa8f67e248 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -4864,6 +4864,20 @@ public void testJoinPrecedence() Optional.of(new NaturalJoin())))); } + @Test + public void testAsofJoinParsing() + { + Query asofJoinQuery = (Query) SQL_PARSER.createStatement("SELECT * FROM a ASOF JOIN b ON a.k = b.k AND b.t <= a.t"); + QuerySpecification asofSpecification = (QuerySpecification) asofJoinQuery.getQueryBody(); + Join asofJoin = (Join) asofSpecification.getFrom().orElseThrow(); + assertThat(asofJoin.getType()).isEqualTo(Join.Type.ASOF); + + Query asofLeftJoinQuery = (Query) SQL_PARSER.createStatement("SELECT * FROM a ASOF LEFT JOIN b ON b.t <= a.t"); + QuerySpecification asofLeftSpecification = (QuerySpecification) asofLeftJoinQuery.getQueryBody(); + Join asofLeftJoin = (Join) asofLeftSpecification.getFrom().orElseThrow(); + assertThat(asofLeftJoin.getType()).isEqualTo(Join.Type.ASOF_LEFT); + } + @Test public void testUnnest() { diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index c9e0f2e6dd2f..49dad3193f5c 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -50,7 +50,7 @@ private static Stream statements() Arguments.of("select * from 'oops", "line 1:15: mismatched input '''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'TABLE', 'UNNEST', "), Arguments.of("select *\nfrom x\nfrom", - "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FETCH', 'FOR', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', " + + "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'ASOF', 'CROSS', 'EXCEPT', 'FETCH', 'FOR', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', " + "'LIMIT', 'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', , "), Arguments.of("select *\nfrom x\nwhere from", "line 3:7: mismatched input 'from'. Expecting: "), @@ -121,7 +121,7 @@ private static Stream statements() Arguments.of("SELECT foo(*) filter (", "line 1:23: mismatched input ''. Expecting: 'WHERE'"), Arguments.of("SELECT * FROM t t x", - "line 1:19: mismatched input 'x'. Expecting: '(', ',', 'CROSS', 'EXCEPT', 'FETCH', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', " + + "line 1:19: mismatched input 'x'. Expecting: '(', ',', 'ASOF', 'CROSS', 'EXCEPT', 'FETCH', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', " + "'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', "), Arguments.of("SELECT * FROM t WHERE EXISTS (", "line 1:31: mismatched input ''. Expecting: "), diff --git a/docs/src/main/sphinx/language/reserved.md b/docs/src/main/sphinx/language/reserved.md index 5edee925719d..1e73d2b2e81b 100644 --- a/docs/src/main/sphinx/language/reserved.md +++ b/docs/src/main/sphinx/language/reserved.md @@ -12,6 +12,7 @@ be quoted (using double quotes) in order to be used as an identifier. | `ALTER` | reserved | reserved | | `AND` | reserved | reserved | | `AS` | reserved | reserved | +| `ASOF` | | | | `AUTO` | | | | `BETWEEN` | reserved | reserved | | `BY` | reserved | reserved | diff --git a/docs/src/main/sphinx/sql/select.md b/docs/src/main/sphinx/sql/select.md index b50e66d29b8f..a2693785be70 100644 --- a/docs/src/main/sphinx/sql/select.md +++ b/docs/src/main/sphinx/sql/select.md @@ -52,6 +52,7 @@ and `join_type` is one of LEFT [ OUTER ] JOIN RIGHT [ OUTER ] JOIN FULL [ OUTER ] JOIN +ASOF [ LEFT ] JOIN CROSS JOIN ``` @@ -1402,6 +1403,61 @@ For more information, see [`JSON_TABLE`](json-table). Joins allow you to combine data from multiple relations. +### ASOF JOIN + +An ASOF join matches each row on the left side to the "closest" row on the right side according to an inequality predicate comparing an expression from the right relation to an expression from the left relation. It is commonly used for time or sequence based nearest-neighbor lookups. + +Syntax: + +``` +SELECT ... +FROM left_relation +ASOF [ LEFT ] JOIN right_relation + ON AND +``` + +Rules and behavior: + +- Inequality: Exactly one inequality predicate is required that compares an expression from the right relation to an expression from the left relation using one of `<`, `<=`, `>` or `>=`. +- Equality: Any number of equality conjuncts (e.g., `left.k = right.k`) can appear in the `ON` clause. +- Closest match: The engine picks the nearest matching right-side row allowed by the direction of the inequality and any additional predicates in the `ON` clause. +- `ASOF LEFT`: Preserves all left rows, returning `NULL`s for right columns when there is no match. + +Examples: + +Latest quote at trade time: + +``` +WITH + trades(ts, sym) AS (VALUES (TIMESTAMP '2024-01-01 09:30:05', 'ABC'), + (TIMESTAMP '2024-01-01 09:30:07', 'ABC')), + quotes(ts, sym, price) AS (VALUES (TIMESTAMP '2024-01-01 09:30:00', 'ABC', 100.0), + (TIMESTAMP '2024-01-01 09:30:06', 'ABC', 101.5)) +SELECT trades.ts, quotes.price +FROM trades ASOF JOIN quotes + ON trades.sym = quotes.sym AND quotes.ts <= trades.ts +ORDER BY 1; +-- 2024-01-01 09:30:05 100.0 +-- 2024-01-01 09:30:07 101.5 +``` + +ASOF LEFT with missing match: + +``` +WITH + readings(read_time) AS ( + VALUES (TIMESTAMP '2024-01-01 10:00:00'), + (TIMESTAMP '2024-01-01 10:05:00')), + calibrations(effective_time, factor) AS ( + VALUES (TIMESTAMP '2024-01-01 10:01:00', 1.02)) +SELECT readings.read_time, calibrations.factor +FROM readings ASOF LEFT JOIN calibrations + ON calibrations.effective_time >= readings.read_time +ORDER BY 1; +-- 2024-01-01 10:00:00 1.02 +-- 2024-01-01 10:05:00 NULL +``` + ### CROSS JOIN A cross join returns the Cartesian product (all combinations) of two