diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index da5f3363be2c..890fe0ea64ad 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -481,6 +481,10 @@ relationPrimary )? ((ERROR | EMPTY) ON ERROR)? ')' #jsonTable + | NEAREST '(' + FROM relation + (WHERE where=booleanExpression)? + MATCH match=booleanExpression ')' #nearest ; jsonTableColumn @@ -1052,7 +1056,7 @@ nonReserved | KEEP | KEY | KEYS | LANGUAGE | LAST | LATERAL | LEADING | LEAVE | LEVEL | LIMIT | LOCAL | LOGICAL | LOOP | MAP | MATCH | MATCHED | MATCHES | MATCH_RECOGNIZE | MATERIALIZED | MEASURES | MERGE | MINUTE | MONTH - | NESTED | NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS + | NESTED | NEXT | NFC | NFD | NEAREST | NFKC | NFKD | NO | NONE | NULLIF | NULLS | OBJECT | OF | OFFSET | OMIT | ONE | ONLY | OPTION | ORDINALITY | OUTPUT | OVER | OVERFLOW | PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERIOD | PERMUTE | PLAN | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE | QUOTES @@ -1232,6 +1236,7 @@ MONTH: 'MONTH'; NATURAL: 'NATURAL'; NESTED: 'NESTED'; NEXT: 'NEXT'; +NEAREST: 'NEAREST'; NFC : 'NFC'; NFD : 'NFD'; NFKC : 'NFKC'; diff --git a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java index cb44176dc6ef..028ee53640ef 100644 --- a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java +++ b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java @@ -190,6 +190,7 @@ public void test() "MONTH", "NATURAL", "NESTED", + "NEAREST", "NEXT", "NFC", "NFD", diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 
9fe1a646f3f9..87e919c65072 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -60,6 +60,7 @@ import io.trino.sql.analyzer.PatternRecognitionAnalysis.PatternInputAnalysis; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.tree.AllColumns; +import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.DataType; import io.trino.sql.tree.ExistsPredicate; import io.trino.sql.tree.Expression; @@ -73,6 +74,7 @@ import io.trino.sql.tree.JsonTableColumnDefinition; import io.trino.sql.tree.LambdaArgumentDeclaration; import io.trino.sql.tree.MeasureDefinition; +import io.trino.sql.tree.Nearest; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.Offset; @@ -242,6 +244,7 @@ public class Analysis private final Map, Map> defaultColumnValues = new LinkedHashMap<>(); private final Map, UnnestAnalysis> unnestAnalysis = new LinkedHashMap<>(); + private final Map, NearestAnalysis> nearestAnalysis = new LinkedHashMap<>(); private Optional create = Optional.empty(); private Optional insert = Optional.empty(); private Optional refreshMaterializedView = Optional.empty(); @@ -1001,6 +1004,16 @@ public UnnestAnalysis getUnnest(Unnest node) return unnestAnalysis.get(NodeRef.of(node)); } + public void setNearest(Nearest node, NearestAnalysis analysis) + { + nearestAnalysis.put(NodeRef.of(node), analysis); + } + + public NearestAnalysis getNearest(Nearest node) + { + return nearestAnalysis.get(NodeRef.of(node)); + } + public void addTableColumnReferences(AccessControl accessControl, Identity identity, Multimap tableColumnMap) { AccessControlInfo accessControlInfo = new AccessControlInfo(accessControl, identity); @@ -2609,6 +2622,17 @@ public record JsonTableAnalysis( } } + public record NearestAnalysis( + ComparisonExpression.Operator operator, + Expression candidateExpression) + { + public NearestAnalysis + { + 
requireNonNull(operator, "operator is null"); + requireNonNull(candidateExpression, "candidateExpression is null"); + } + } + public record CorrespondingAnalysis(List indexes, List fields) { public CorrespondingAnalysis diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index c571a7e7d7bc..43b95a80fe7d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -110,6 +110,7 @@ import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; import io.trino.sql.analyzer.Analysis.JsonTableAnalysis; import io.trino.sql.analyzer.Analysis.MergeAnalysis; +import io.trino.sql.analyzer.Analysis.NearestAnalysis; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.Analysis.SourceColumn; @@ -136,6 +137,7 @@ import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.Comment; import io.trino.sql.tree.Commit; +import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Corresponding; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateMaterializedView; @@ -198,6 +200,7 @@ import io.trino.sql.tree.MergeInsert; import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.NaturalJoin; +import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeLocation; @@ -414,6 +417,7 @@ import static io.trino.sql.analyzer.ExpressionTreeUtils.extractWindowMeasures; import static io.trino.sql.analyzer.Scope.BasisType.TABLE; import static io.trino.sql.analyzer.ScopeReferenceExtractor.getReferencesToScope; +import static io.trino.sql.analyzer.ScopeReferenceExtractor.hasReferencesToScope; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static 
io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; @@ -1702,6 +1706,62 @@ protected Scope visitLateral(Lateral node, Optional scope) return createAndAssignScope(node, scope, queryScope.getRelationType()); } + @Override + protected Scope visitNearest(Nearest node, Optional scope) + { + if (scope.isEmpty()) { + throw semanticException(NOT_SUPPORTED, node, "NEAREST is only supported on the right side of CROSS JOIN, INNER JOIN, LEFT JOIN, or an implicit join"); + } + + Scope leftScope = scope.orElseThrow(); + if (leftScope.getRelationType().getAllFieldCount() == 0) { + throw semanticException(NOT_SUPPORTED, node, "NEAREST is only supported on the right side of CROSS JOIN, INNER JOIN, LEFT JOIN, or an implicit join"); + } + + // NEAREST is treated as a lateral relation by visitJoin(), but in the current implementation only the + // top-level WHERE and MATCH clauses may correlate with the left side of the join. + // The inner FROM relation is analyzed against the query boundary so it behaves like an + // uncorrelated relation body. + // + // TODO: If NEAREST is extended to allow correlation inside FROM , analyze the + // FROM relation against Optional.of(leftScope) instead, similar to LATERAL. + Scope sourceScope = process(node.getRelation(), Optional.of(leftScope.getQueryBoundaryScope())); + + // Re-wrap the analyzed FROM relation in a scope whose parent is the left side of the + // join. This allows WHERE and MATCH to reference both the FROM relation fields and + // the left-side fields while keeping the inner FROM relation itself uncorrelated. 
+ Scope nearestScope = createAndAssignScope(node, scope, sourceScope.getRelationType()); + + node.getWhere().ifPresent(where -> { + verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, where, "NEAREST WHERE clause"); + ExpressionAnalysis whereAnalysis = analyzeExpression(where, nearestScope, CorrelationSupport.ALLOWED); + Type whereType = whereAnalysis.getType(where); + if (!whereType.equals(BOOLEAN)) { + if (!whereType.equals(UNKNOWN)) { + throw semanticException(TYPE_MISMATCH, where, "NEAREST WHERE clause must evaluate to a boolean: actual type %s", whereType); + } + analysis.addCoercion(where, BOOLEAN); + } + verifyNoCorrelatedSubqueries(where, leftScope, sourceScope, "NEAREST WHERE clause"); + analysis.recordSubqueries(node, whereAnalysis); + }); + + verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, node.getMatch(), "NEAREST MATCH clause"); + ExpressionAnalysis matchAnalysis = analyzeExpression(node.getMatch(), nearestScope, CorrelationSupport.ALLOWED); + Type matchType = matchAnalysis.getType(node.getMatch()); + if (!matchType.equals(BOOLEAN)) { + if (!matchType.equals(UNKNOWN)) { + throw semanticException(TYPE_MISMATCH, node.getMatch(), "NEAREST MATCH clause must evaluate to a boolean: actual type %s", matchType); + } + analysis.addCoercion(node.getMatch(), BOOLEAN); + } + verifyNoCorrelatedSubqueries(node.getMatch(), leftScope, sourceScope, "NEAREST MATCH clause"); + analysis.recordSubqueries(node, matchAnalysis); + analysis.setNearest(node, analyzeNearestMatch(node, nearestScope, leftScope)); + + return nearestScope; + } + @Override protected Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional scope) { @@ -3443,32 +3503,43 @@ protected Scope visitJoin(Join node, Optional scope) }); } if (isUnnestRelation(node.getRight())) { - if (criteria != null) { - if (!(criteria instanceof JoinOn joinOn) || !joinOn.getExpression().equals(TRUE_LITERAL)) { - throw semanticException( - 
NOT_SUPPORTED, - criteria instanceof JoinOn joinOn ? joinOn.getExpression() : node, - "%s JOIN involving UNNEST is only supported with condition ON TRUE", - node.getType().name()); - } + if (criteria != null && !isJoinOnTrue(criteria)) { + throw semanticException( + NOT_SUPPORTED, + getJoinErrorLocation(node, criteria), + "%s JOIN involving UNNEST is only supported with condition ON TRUE", + node.getType().name()); } } else if (isJsonTable(node.getRight())) { - if (criteria != null) { - if (!(criteria instanceof JoinOn joinOn) || !joinOn.getExpression().equals(TRUE_LITERAL)) { - throw semanticException( - NOT_SUPPORTED, - criteria instanceof JoinOn joinOn ? joinOn.getExpression() : node, - "%s JOIN involving JSON_TABLE is only supported with condition ON TRUE", - node.getType().name()); - } + if (criteria != null && !isJoinOnTrue(criteria)) { + throw semanticException( + NOT_SUPPORTED, + getJoinErrorLocation(node, criteria), + "%s JOIN involving JSON_TABLE is only supported with condition ON TRUE", + node.getType().name()); + } + } + else if (isNearestRelation(node.getRight())) { + if (criteria instanceof JoinUsing) { + throw semanticException(NOT_SUPPORTED, node, "JOIN USING involving NEAREST is not supported"); + } + if ((node.getType() == Join.Type.INNER || node.getType() == LEFT) && !isJoinOnTrue(criteria)) { + throw semanticException( + NOT_SUPPORTED, + getJoinErrorLocation(node, criteria), + "%s JOIN involving NEAREST is only supported with condition ON TRUE", + node.getType().name()); + } + if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT && node.getType() != Join.Type.INNER && node.getType() != LEFT) { + throw semanticException(NOT_SUPPORTED, node, "%s JOIN involving NEAREST is not supported", node.getType().name()); } } else if (node.getType() == FULL) { - if (!(criteria instanceof JoinOn joinOn) || !joinOn.getExpression().equals(TRUE_LITERAL)) { + if (!isJoinOnTrue(criteria)) { throw semanticException( NOT_SUPPORTED, - 
criteria instanceof JoinOn joinOn ? joinOn.getExpression() : node, + getJoinErrorLocation(node, criteria), "FULL JOIN involving LATERAL relation is only supported with condition ON TRUE"); } } @@ -4047,7 +4118,7 @@ private boolean isLateralRelation(Relation node) if (node instanceof AliasedRelation aliasedRelation) { return isLateralRelation(aliasedRelation.getRelation()); } - return node instanceof Unnest || node instanceof Lateral || node instanceof JsonTable; + return node instanceof Unnest || node instanceof Lateral || node instanceof JsonTable || node instanceof Nearest; } private boolean isUnnestRelation(Relation node) @@ -4066,6 +4137,69 @@ private boolean isJsonTable(Relation node) return node instanceof JsonTable; } + private boolean isNearestRelation(Relation node) + { + if (node instanceof AliasedRelation aliasedRelation) { + return isNearestRelation(aliasedRelation.getRelation()); + } + return node instanceof Nearest; + } + + private static boolean isJoinOnTrue(JoinCriteria criteria) + { + return criteria instanceof JoinOn joinOn && joinOn.getExpression().equals(TRUE_LITERAL); + } + + private static Node getJoinErrorLocation(Join join, JoinCriteria criteria) + { + if (criteria instanceof JoinOn joinOn) { + return joinOn.getExpression(); + } + return join; + } + + private void verifyNoCorrelatedSubqueries(Expression expression, Scope leftScope, Scope sourceScope, String description) + { + for (SubqueryExpression subquery : extractExpressions(ImmutableList.of(expression), SubqueryExpression.class)) { + if (hasReferencesToScope(subquery, analysis, leftScope) || hasReferencesToScope(subquery, analysis, sourceScope)) { + throw semanticException(UNSUPPORTED_SUBQUERY, subquery, "Correlated subqueries are not supported in %s", description); + } + } + } + + private NearestAnalysis analyzeNearestMatch(Nearest node, Scope nearestScope, Scope leftScope) + { + if (!(node.getMatch() instanceof ComparisonExpression comparison)) { + throw 
semanticException(NOT_SUPPORTED, node.getMatch(), "NEAREST MATCH clause must be a comparison expression"); + } + + // MATCH is analyzed in nearestScope, which exposes fields from the FROM relation locally + // and the left join input through the parent scope. + boolean leftReferencesFromRelation = hasReferencesToScope(comparison.getLeft(), analysis, nearestScope); + boolean rightReferencesFromRelation = hasReferencesToScope(comparison.getRight(), analysis, nearestScope); + boolean leftReferencesOuterRelation = hasReferencesToScope(comparison.getLeft(), analysis, leftScope); + boolean rightReferencesOuterRelation = hasReferencesToScope(comparison.getRight(), analysis, leftScope); + if (leftReferencesFromRelation == rightReferencesFromRelation) { + throw semanticException(NOT_SUPPORTED, node.getMatch(), "NEAREST MATCH clause must compare one FROM relation expression with one non-FROM expression"); + } + + if ((leftReferencesFromRelation && leftReferencesOuterRelation) || (rightReferencesFromRelation && rightReferencesOuterRelation)) { + throw semanticException(NOT_SUPPORTED, node.getMatch(), "NEAREST MATCH clause must keep FROM relation and non-FROM expressions on opposite sides"); + } + + Expression candidateExpression = leftReferencesFromRelation ? comparison.getLeft() : comparison.getRight(); + + ComparisonExpression.Operator operator = leftReferencesFromRelation ? 
comparison.getOperator() : comparison.getOperator().flip(); + if (operator != ComparisonExpression.Operator.LESS_THAN && + operator != ComparisonExpression.Operator.LESS_THAN_OR_EQUAL && + operator != ComparisonExpression.Operator.GREATER_THAN && + operator != ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL) { + throw semanticException(NOT_SUPPORTED, node.getMatch(), "NEAREST MATCH clause must use <, <=, >, or >="); + } + + return new NearestAnalysis(operator, candidateExpression); + } + @Override protected Scope visitValues(Values node, Optional scope) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 6f869cfede7a..3e49e8550f68 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -34,6 +34,7 @@ import io.trino.operator.table.json.JsonTableQueryColumn; import io.trino.operator.table.json.JsonTableValueColumn; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SortOrder; import io.trino.spi.function.table.TableArgument; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.RowType; @@ -42,6 +43,7 @@ import io.trino.sql.analyzer.Analysis; import io.trino.sql.analyzer.Analysis.CorrespondingAnalysis; import io.trino.sql.analyzer.Analysis.JsonTableAnalysis; +import io.trino.sql.analyzer.Analysis.NearestAnalysis; import io.trino.sql.analyzer.Analysis.TableArgumentAnalysis; import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; import io.trino.sql.analyzer.Analysis.UnnestAnalysis; @@ -66,6 +68,7 @@ import io.trino.sql.ir.Row; import io.trino.sql.planner.QueryPlanner.PlanAndMappings; import io.trino.sql.planner.TranslationMap.ParametersRow; +import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import 
io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.ExceptNode; @@ -85,6 +88,7 @@ import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TopNRankingNode; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UnnestNode; import io.trino.sql.planner.plan.ValuesNode; @@ -124,6 +128,7 @@ import io.trino.sql.tree.Lateral; import io.trino.sql.tree.MeasureDefinition; import io.trino.sql.tree.NaturalJoin; +import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; @@ -872,12 +877,23 @@ protected RelationPlan visitLateral(Lateral node, Void context) return new RelationPlan(plan.getRoot(), analysis.getScope(node), plan.getFieldMappings(), outerContext); } + @Override + protected RelationPlan visitNearest(Nearest node, Void context) + { + throw semanticException(NOT_SUPPORTED, node, "NEAREST is only supported on the right side of CROSS JOIN, INNER JOIN, LEFT JOIN, or an implicit join"); + } + @Override protected RelationPlan visitJoin(Join node, Void context) { // TODO: translate the RIGHT join into a mirrored LEFT join when we refactor (@martint) RelationPlan leftPlan = process(node.getLeft(), context); + Optional nearest = getNearest(node.getRight()); + if (nearest.isPresent()) { + return planJoinNearest(node, leftPlan, nearest.get()); + } + Optional unnest = getUnnest(node.getRight()); if (unnest.isPresent()) { return planJoinUnnest(leftPlan, node, unnest.get()); @@ -1220,6 +1236,17 @@ private static Optional getJsonTable(Relation relation) return Optional.empty(); } + private static Optional getNearest(Relation relation) + { + if (relation instanceof AliasedRelation aliasedRelation) { + return getNearest(aliasedRelation.getRelation()); + } + if (relation instanceof 
Nearest nearest) { + return Optional.of(nearest); + } + return Optional.empty(); + } + private static Optional getLateral(Relation relation) { if (relation instanceof AliasedRelation aliasedRelation) { @@ -1231,6 +1258,109 @@ private static Optional getLateral(Relation relation) return Optional.empty(); } + private RelationPlan planJoinNearest(Join join, RelationPlan leftPlan, Nearest nearest) + { + checkArgument(join.getType() == CROSS || join.getType() == IMPLICIT || join.getType() == Join.Type.INNER || join.getType() == LEFT, "Unsupported join type for NEAREST: %s", join.getType()); + + Symbol uniqueSymbol = symbolAllocator.newSymbol("nearest_left_row", BIGINT); + RelationPlan leftPlanWithId = new RelationPlan( + new AssignUniqueId(idAllocator.getNextId(), leftPlan.getRoot(), uniqueSymbol), + leftPlan.getScope(), + leftPlan.getFieldMappings(), + outerContext); + + RelationPlan rightPlan = process(nearest.getRelation(), null); + List predicates = ImmutableList.builder() + .addAll(nearest.getWhere().stream().toList()) + .add(nearest.getMatch()) + .build(); + + PlanBuilder leftPlanBuilder = newPlanBuilder(leftPlanWithId, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); + PlanBuilder rightPlanBuilder = newPlanBuilder(rightPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); + Analysis.SubqueryAnalysis subqueries = analysis.getSubqueries(nearest); + for (io.trino.sql.tree.Expression predicate : predicates) { + Set dependencies = NamesExtractor.extractNamesNoSubqueries(predicate, analysis.getColumnReferences()); + if (dependencies.stream().allMatch(leftPlan.getScope().getRelationType()::canResolve)) { + leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, predicate, subqueries); + } + else { + // Correlated subqueries in NEAREST predicates are rejected during analysis. 
+ // Any subquery reaching this mixed-predicate path is therefore uncorrelated and can be planned + // on one side before building the combined candidate join, so the rewritten predicate can still + // be attached to the join condition, which matters for LEFT JOIN NEAREST semantics. + rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, predicate, subqueries); + } + } + + List candidateOutputs = ImmutableList.builder() + .addAll(leftPlanWithId.getFieldMappings()) + .addAll(rightPlan.getFieldMappings()) + .build(); + // WHERE and MATCH were analyzed in the NEAREST scope. That scope exposes the FROM relation fields locally + // and the left join input through the parent scope, so rewriting those expressions requires symbol mappings + // for both join sides even though the expression scope itself remains analysis.getScope(nearest). + TranslationMap candidateTranslations = new TranslationMap( + outerContext, + analysis.getScope(nearest), + analysis, + lambdaDeclarationToSymbolMap, + candidateOutputs, + session, + plannerContext) + .withAdditionalMappings(leftPlanBuilder.getTranslations().getMappings()) + .withAdditionalMappings(rightPlanBuilder.getTranslations().getMappings()); + + PlanNode candidateRoot = new JoinNode( + idAllocator.getNextId(), + join.getType() == Join.Type.LEFT ? 
JoinType.LEFT : JoinType.INNER, + leftPlanBuilder.getRoot(), + rightPlanBuilder.getRoot(), + ImmutableList.of(), + leftPlanBuilder.getRoot().getOutputSymbols(), + rightPlanBuilder.getRoot().getOutputSymbols(), + false, + Optional.of(IrUtils.and(predicates.stream() + .map(expression -> coerceIfNecessary(analysis, expression, candidateTranslations.rewrite(expression))) + .collect(toImmutableList()))), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()); + RelationPlan candidatePlan = new RelationPlan(candidateRoot, analysis.getScope(nearest), candidateOutputs, outerContext); + + NearestAnalysis nearestAnalysis = analysis.getNearest(nearest); + PlanBuilder candidateBuilder = newPlanBuilder(candidatePlan, analysis, lambdaDeclarationToSymbolMap, candidateTranslations.getMappings(), session, plannerContext) + .appendProjections(ImmutableList.of(nearestAnalysis.candidateExpression()), symbolAllocator, idAllocator); + + Symbol orderingSymbol = candidateBuilder.translate(nearestAnalysis.candidateExpression()); + SortOrder sortOrder = switch (nearestAnalysis.operator()) { + case LESS_THAN, LESS_THAN_OR_EQUAL -> SortOrder.DESC_NULLS_LAST; + case GREATER_THAN, GREATER_THAN_OR_EQUAL -> SortOrder.ASC_NULLS_LAST; + default -> throw new IllegalArgumentException("Unsupported NEAREST operator: " + nearestAnalysis.operator()); + }; + PlanNode rankedCandidates = new TopNRankingNode( + idAllocator.getNextId(), + candidateBuilder.getRoot(), + new DataOrganizationSpecification( + ImmutableList.of(uniqueSymbol), + Optional.of(new OrderingScheme(ImmutableList.of(orderingSymbol), ImmutableMap.of(orderingSymbol, sortOrder)))), + TopNRankingNode.RankingType.ROW_NUMBER, + symbolAllocator.newSymbol("nearest_ranking", BIGINT), + 1, + false); + + List outputSymbols = ImmutableList.builder() + .addAll(leftPlan.getFieldMappings()) + .addAll(rightPlan.getFieldMappings()) + .build(); + + return new RelationPlan( + new ProjectNode(idAllocator.getNextId(), rankedCandidates, 
Assignments.identity(outputSymbols)), + analysis.getScope(join), + outputSymbols, + outerContext); + } + private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Lateral lateral) { PlanBuilder leftPlanBuilder = newPlanBuilder(leftPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 6ee9b4606be9..fad7860e0368 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -190,6 +190,7 @@ import static io.trino.spi.StandardErrorCode.TOO_MANY_ARGUMENTS; import static io.trino.spi.StandardErrorCode.TOO_MANY_GROUPING_SETS; import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; +import static io.trino.spi.StandardErrorCode.UNSUPPORTED_SUBQUERY; import static io.trino.spi.StandardErrorCode.VIEW_IS_RECURSIVE; import static io.trino.spi.StandardErrorCode.VIEW_IS_STALE; import static io.trino.spi.connector.ConnectorMaterializedViewDefinition.WhenStaleBehavior.INLINE; @@ -4360,6 +4361,287 @@ public void testJoinLateral() .hasMessage("line 1:63: FULL JOIN involving LATERAL relation is only supported with condition ON TRUE"); } + @Test + public void testJoinNearest() + { + analyze(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts < trades.ts + ) + """); + analyze(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + LEFT JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """); + analyze(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP 
'2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:03')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts > trades.ts + ) + """); + analyze(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + LEFT JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:03')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts >= trades.ts + ) ON TRUE + """); + analyze(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + INNER JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """); + analyze(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01', 10)) quotes(symbol, ts, price) + WHERE quotes.price = (SELECT 10) + MATCH quotes.ts <= date_add('second', (SELECT 0), trades.ts) + ) + """); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.ts <= (SELECT trades.ts) + MATCH quotes.ts <= trades.ts + ) + """) + .hasErrorCode(UNSUPPORTED_SUBQUERY) + .hasMessageContaining("Correlated subqueries are not supported in NEAREST WHERE clause"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.ts <= trades.ts + MATCH quotes.ts <= (SELECT trades.ts) + ) + """) + .hasErrorCode(UNSUPPORTED_SUBQUERY) + .hasMessageContaining("Correlated subqueries are not supported in NEAREST MATCH clause"); + analyze(""" + SELECT * + FROM (VALUES (TIMESTAMP '2020-01-01 
00:00:02')) trades(ts), + NEAREST ( + FROM (VALUES (TIMESTAMP '2020-01-01 00:00:01')) quotes(ts) + MATCH quotes.ts <= trades.ts + ) + """); + + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + INNER JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON 1 = 1 + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("INNER JOIN involving NEAREST is only supported with condition ON TRUE"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + RIGHT JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:03')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts >= trades.ts + ) ON true + """) + .hasErrorCode(INVALID_COLUMN_REFERENCE) + .hasMessageContaining("LATERAL reference not allowed in RIGHT JOIN"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + FULL JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:03')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts >= trades.ts + ) ON true + """) + .hasErrorCode(INVALID_COLUMN_REFERENCE) + .hasMessageContaining("LATERAL reference not allowed in FULL JOIN"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + LEFT JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON 1 = 1 + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("LEFT JOIN involving NEAREST is only supported with condition ON TRUE"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + LEFT JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol 
= trades.symbol + MATCH quotes.ts <= trades.ts + ) USING (symbol) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("JOIN USING involving NEAREST is not supported"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + NATURAL JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + MATCH quotes.ts <= trades.ts + ) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("Natural join not supported"); + assertFails(""" + SELECT * + FROM NEAREST ( + FROM (VALUES (TIMESTAMP '2020-01-01 00:00:01')) quotes(ts) + MATCH quotes.ts <= TIMESTAMP '2020-01-01 00:00:02' + ) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("NEAREST is only supported on the right side of CROSS JOIN, INNER JOIN, LEFT JOIN, or an implicit join"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts = trades.ts + ) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("NEAREST MATCH clause must use <, <=, >, or >="); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <> trades.ts + ) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("NEAREST MATCH clause must use <, <=, >, or >="); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts < trades.ts AND quotes.symbol = trades.symbol + ) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("NEAREST MATCH clause 
must be a comparison expression"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts < quotes.ts + ) + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessageContaining("NEAREST MATCH clause must compare one FROM relation expression with one non-FROM expression"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE 1 + MATCH quotes.ts <= trades.ts + ) + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessageContaining("NEAREST WHERE clause must evaluate to a boolean"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + MATCH 1 + ) + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessageContaining("NEAREST MATCH clause must evaluate to a boolean"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE sum(1) = 1 + MATCH quotes.ts <= trades.ts + ) + """) + .hasErrorCode(EXPRESSION_NOT_SCALAR) + .hasMessageContaining("NEAREST WHERE clause cannot contain aggregations, window functions or grouping operations"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + WHERE grouping(trades.symbol) = 0 + MATCH quotes.ts <= trades.ts + ) + """) + .hasErrorCode(EXPRESSION_NOT_SCALAR) + .hasMessageContaining("NEAREST WHERE clause cannot contain aggregations, window functions or grouping 
operations"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + MATCH row_number() OVER () = 1 + ) + """) + .hasErrorCode(EXPRESSION_NOT_SCALAR) + .hasMessageContaining("NEAREST MATCH clause cannot contain aggregations, window functions or grouping operations"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) + MATCH quotes.ts <= grouping(trades.symbol) + ) + """) + .hasErrorCode(EXPRESSION_NOT_SCALAR) + .hasMessageContaining("NEAREST MATCH clause cannot contain aggregations, window functions or grouping operations"); + assertFails(""" + SELECT * + FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')) trades(symbol, ts) + CROSS JOIN NEAREST ( + FROM (SELECT * FROM (VALUES ('A', TIMESTAMP '2020-01-01 00:00:01')) quotes(symbol, ts) WHERE quotes.symbol = trades.symbol) + MATCH ts <= trades.ts + ) + """) + .hasErrorCode(COLUMN_NOT_FOUND) + .hasMessageContaining("Column 'trades.symbol' cannot be resolved"); + } + @Test public void testNullTreatment() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index a46657a5726e..6311cc7d7c20 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -952,6 +952,91 @@ public void testCorrelatedJoinWithTopN() anyTree(tableScan("nation", ImmutableMap.of("nation_name", "name", "nation_regionkey", "regionkey"))))))))); } + @Test + public void testInnerJoinNearest() + { + assertPlan( + """ + SELECT region.regionkey, nation.name + FROM region + INNER JOIN NEAREST ( + FROM nation + WHERE nation.regionkey = 
region.regionkey + MATCH nation.nationkey < region.regionkey + ) ON TRUE + """, + output( + project( + ImmutableMap.of( + "regionkey", expression(new Reference(BIGINT, "regionkey")), + "name", expression(new Reference(VARCHAR, "name_1"))), + topNRanking( + pattern -> pattern + .specification( + ImmutableList.of("nearest_left_row"), + ImmutableList.of("nationkey"), + ImmutableMap.of("nationkey", SortOrder.DESC_NULLS_LAST)) + .rankingType(ROW_NUMBER) + .maxRankingPerPartition(1) + .partial(false), + join(INNER, builder -> builder + .equiCriteria("regionkey", "regionkey_2") + .distributionType(REPLICATED) + .left(assignUniqueId( + "nearest_left_row", + any(tableScan("region", ImmutableMap.of("regionkey", "regionkey"))))) + .right(exchange( + LOCAL, + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "nationkey"), new Reference(BIGINT, "regionkey_2")), + tableScan("nation", ImmutableMap.of( + "nationkey", "nationkey", + "name_1", "name", + "regionkey_2", "regionkey")))))))))); + } + + @Test + public void testLeftJoinNearest() + { + assertPlan( + """ + SELECT region.regionkey, nation.name + FROM region + LEFT JOIN NEAREST ( + FROM nation + WHERE nation.regionkey = region.regionkey + MATCH nation.nationkey > region.regionkey + ) ON TRUE + """, + output( + project( + ImmutableMap.of( + "regionkey", expression(new Reference(BIGINT, "regionkey")), + "name", expression(new Reference(VARCHAR, "name_1"))), + topNRanking( + pattern -> pattern + .specification( + ImmutableList.of("nearest_left_row"), + ImmutableList.of("nationkey"), + ImmutableMap.of("nationkey", ASC_NULLS_LAST)) + .rankingType(ROW_NUMBER) + .maxRankingPerPartition(1) + .partial(false), + join(LEFT, builder -> builder + .equiCriteria("regionkey", "regionkey_2") + .left(assignUniqueId( + "nearest_left_row", + tableScan("region", ImmutableMap.of("regionkey", "regionkey")))) + .right(exchange( + LOCAL, + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "nationkey"), new Reference(BIGINT, 
"regionkey_2")), + tableScan("nation", ImmutableMap.of( + "nationkey", "nationkey", + "name_1", "name", + "regionkey_2", "regionkey")))))))))); + } + @Test public void testCorrelatedJoinWithNullCondition() { diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestNearest.java b/core/trino-main/src/test/java/io/trino/sql/query/TestNearest.java new file mode 100644 index 000000000000..17bde5dfdc1a --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestNearest.java @@ -0,0 +1,558 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.query; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestNearest +{ + private final QueryAssertions assertions = new QueryAssertions(); + + @AfterAll + public void teardown() + { + assertions.close(); + } + + @Test + public void testCrossJoinNearest() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('A', TIMESTAMP '2020-01-01 00:00:04'), + ('B', TIMESTAMP '2020-01-01 00:00:03')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('B', TIMESTAMP '2020-01-01 00:00:02', 20)) + SELECT trades.symbol, trades.ts, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + .matches("VALUES " + + "('A', TIMESTAMP '2020-01-01 00:00:02', 10), " + + "('A', TIMESTAMP '2020-01-01 00:00:04', 11), " + + "('B', TIMESTAMP '2020-01-01 00:00:03', 20)"); + } + + @Test + public void testImplicitJoinNearest() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('A', TIMESTAMP '2020-01-01 00:00:04'), + ('B', TIMESTAMP '2020-01-01 00:00:03')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('B', TIMESTAMP '2020-01-01 00:00:02', 20)) + SELECT trades.symbol, trades.ts, quotes.price + FROM trades, + NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH 
quotes.ts <= trades.ts + ) + """)) + .matches("VALUES " + + "('A', TIMESTAMP '2020-01-01 00:00:02', 10), " + + "('A', TIMESTAMP '2020-01-01 00:00:04', 11), " + + "('B', TIMESTAMP '2020-01-01 00:00:03', 20)"); + } + + @Test + public void testInnerJoinNearest() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('A', TIMESTAMP '2020-01-01 00:00:04'), + ('B', TIMESTAMP '2020-01-01 00:00:03')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('B', TIMESTAMP '2020-01-01 00:00:02', 20)) + SELECT trades.symbol, trades.ts, quotes.price + FROM trades + INNER JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """)) + .matches("VALUES " + + "('A', TIMESTAMP '2020-01-01 00:00:02', 10), " + + "('A', TIMESTAMP '2020-01-01 00:00:04', 11), " + + "('B', TIMESTAMP '2020-01-01 00:00:03', 20)"); + } + + @Test + public void testLeftJoinNearest() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('C', TIMESTAMP '2020-01-01 00:00:01')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10)) + SELECT trades.symbol, quotes.price + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts < trades.ts + ) ON TRUE + """)) + .matches("VALUES ('A', 10), ('C', CAST(null AS integer))"); + } + + @Test + public void testCrossJoinNearestForward() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('A', TIMESTAMP '2020-01-01 00:00:04'), + ('B', TIMESTAMP '2020-01-01 00:00:01')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('A', TIMESTAMP '2020-01-01 00:00:05', 12), + ('B', TIMESTAMP '2020-01-01 00:00:01', 
20)) + SELECT trades.symbol, trades.ts, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts > trades.ts + ) + """)) + .matches("VALUES " + + "('A', TIMESTAMP '2020-01-01 00:00:02', 11), " + + "('A', TIMESTAMP '2020-01-01 00:00:04', 12)"); + } + + @Test + public void testLeftJoinNearestForward() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:03'), + ('B', TIMESTAMP '2020-01-01 00:00:01'), + ('C', TIMESTAMP '2020-01-01 00:00:02')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:03', 10), + ('B', TIMESTAMP '2020-01-01 00:00:02', 20)) + SELECT trades.symbol, quotes.price + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts >= trades.ts + ) ON TRUE + """)) + .matches("VALUES ('A', 10), ('B', 20), ('C', CAST(null AS integer))"); + } + + @Test + public void testCrossJoinNearestWithConstantAnchor() + { + assertThat(assertions.query( + """ + WITH + trades(id) AS ( + VALUES + 1, + 2), + quotes(ts, price) AS ( + VALUES + (TIMESTAMP '2020-01-01 00:00:01', 10), + (TIMESTAMP '2020-01-01 00:00:03', 11)) + SELECT trades.id, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + MATCH quotes.ts <= TIMESTAMP '2020-01-01 00:00:02' + ) + """)) + .matches("VALUES (1, 10), (2, 10)"); + } + + @Test + public void testImplicitJoinNearestWithConstantAnchor() + { + assertThat(assertions.query( + """ + WITH + trades(id) AS ( + VALUES + 1, + 2), + quotes(ts, price) AS ( + VALUES + (TIMESTAMP '2020-01-01 00:00:01', 10), + (TIMESTAMP '2020-01-01 00:00:03', 11)) + SELECT trades.id, quotes.price + FROM trades, + NEAREST ( + FROM quotes + MATCH quotes.ts <= TIMESTAMP '2020-01-01 00:00:02' + ) + """)) + .matches("VALUES (1, 10), (2, 10)"); + } + + @Test + public void testLeftJoinNearestWithNullMatchKeys() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, 
ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('B', CAST(NULL AS timestamp)), + ('C', TIMESTAMP '2020-01-01 00:00:01')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('B', TIMESTAMP '2020-01-01 00:00:01', 20), + ('C', CAST(NULL AS timestamp), 30)) + SELECT trades.symbol, quotes.price + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """)) + .matches("VALUES ('A', 10), ('B', CAST(null AS integer)), ('C', CAST(null AS integer))"); + } + + @Test + public void testLeftJoinNearestWithNullOuterMatchKey() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('B', CAST(NULL AS timestamp))), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('B', TIMESTAMP '2020-01-01 00:00:01', 20)) + SELECT trades.symbol, quotes.price + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """)) + .matches("VALUES ('A', 10), ('B', CAST(null AS integer))"); + } + + @Test + public void testLeftJoinNearestWithNullCandidateMatchKey() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('B', TIMESTAMP '2020-01-01 00:00:02')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', CAST(NULL AS timestamp), 99), + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('B', CAST(NULL AS timestamp), 20)) + SELECT trades.symbol, quotes.price + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """)) + .matches("VALUES ('A', 10), ('B', CAST(null AS integer))"); + } + + @Test + public void testCrossJoinNearestWithNullMatchKeys() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP 
'2020-01-01 00:00:02'), + ('B', CAST(NULL AS timestamp)), + ('C', TIMESTAMP '2020-01-01 00:00:02')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('B', TIMESTAMP '2020-01-01 00:00:01', 20), + ('C', CAST(NULL AS timestamp), 30)) + SELECT trades.symbol, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + .matches("VALUES ('A', 10)"); + } + + @Test + public void testCrossJoinNearestWithNullWherePredicate() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + (CAST(NULL AS varchar), TIMESTAMP '2020-01-01 00:00:02')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + (CAST(NULL AS varchar), TIMESTAMP '2020-01-01 00:00:01', 20)) + SELECT trades.symbol, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + .matches("VALUES (CAST('A' AS varchar), 10)"); + } + + @Test + public void testCrossJoinNearestWithoutWhere() + { + assertThat(assertions.query( + """ + WITH + trades(id, ts) AS ( + VALUES + (1, TIMESTAMP '2020-01-01 00:00:02'), + (2, TIMESTAMP '2020-01-01 00:00:04')), + quotes(ts, price) AS ( + VALUES + (TIMESTAMP '2020-01-01 00:00:01', 10), + (TIMESTAMP '2020-01-01 00:00:03', 11)) + SELECT trades.id, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + MATCH quotes.ts <= trades.ts + ) + """)) + .matches("VALUES (1, 10), (2, 11)"); + } + + @Test + public void testLeftJoinNearestWithoutWhere() + { + assertThat(assertions.query( + """ + WITH + trades(id, ts) AS ( + VALUES + (1, TIMESTAMP '2020-01-01 00:00:00'), + (2, TIMESTAMP '2020-01-01 00:00:04')), + quotes(ts, price) AS ( + VALUES + (TIMESTAMP '2020-01-01 00:00:01', 10), + (TIMESTAMP '2020-01-01 00:00:03', 11)) + SELECT trades.id, quotes.price + FROM trades + LEFT 
JOIN NEAREST ( + FROM quotes + MATCH quotes.ts <= trades.ts + ) ON TRUE + """)) + .matches("VALUES (1, CAST(null AS integer)), (2, 11)"); + } + + @Test + public void testCrossJoinNearestWithDuplicateCandidates() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:04'), + ('B', TIMESTAMP '2020-01-01 00:00:03')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('B', TIMESTAMP '2020-01-01 00:00:02', 20), + ('B', TIMESTAMP '2020-01-01 00:00:02', 20)) + SELECT trades.symbol, quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + .matches("VALUES ('A', 11), ('B', 20)"); + } + + @Test + public void testCrossJoinNearestWithSubqueriesInPredicates() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:02'), + ('A', TIMESTAMP '2020-01-01 00:00:04')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:01', 10), + ('A', TIMESTAMP '2020-01-01 00:00:03', 11)) + SELECT quotes.price + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + AND quotes.price >= (SELECT 10) + MATCH quotes.ts <= date_add('second', (SELECT 0), trades.ts) + ) + """)) + .matches("VALUES 10, 11"); + } + + @Test + public void testCrossJoinNearestWithTies() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:04')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('A', TIMESTAMP '2020-01-01 00:00:03', 12), + ('A', TIMESTAMP '2020-01-01 00:00:01', 10)) + SELECT count(*), min(quotes.ts), max(quotes.ts) + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + 
.matches("VALUES (BIGINT '1', TIMESTAMP '2020-01-01 00:00:03', TIMESTAMP '2020-01-01 00:00:03')"); + } + + @Test + public void testLeftJoinNearestWithTies() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:04'), + ('B', TIMESTAMP '2020-01-01 00:00:01')), + quotes(symbol, ts, price) AS ( + VALUES + ('A', TIMESTAMP '2020-01-01 00:00:03', 11), + ('A', TIMESTAMP '2020-01-01 00:00:03', 12)) + SELECT trades.symbol, count(quotes.price), min(quotes.ts), max(quotes.ts) + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + GROUP BY trades.symbol + """)) + .matches("VALUES " + + "('A', BIGINT '1', CAST(TIMESTAMP '2020-01-01 00:00:03' AS timestamp(0)), CAST(TIMESTAMP '2020-01-01 00:00:03' AS timestamp(0))), " + + "('B', BIGINT '0', CAST(null AS timestamp(0)), CAST(null AS timestamp(0)))"); + } + + @Test + public void testJoinNearestAliasing() + { + assertThat(assertions.query( + """ + WITH + trades(symbol, ts) AS ( + VALUES ('A', TIMESTAMP '2020-01-01 00:00:02')), + quotes(symbol, ts, price) AS ( + VALUES ('A', TIMESTAMP '2020-01-01 00:00:01', 10)) + SELECT t.symbol, q_price + FROM trades t + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = t.symbol + MATCH quotes.ts <= t.ts + ) q(q_symbol, q_ts, q_price) + """)) + .matches("VALUES ('A', 10)"); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 5c7a07ece9f2..da9cede3b3e3 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -102,6 +102,7 @@ import io.trino.sql.tree.MergeInsert; import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.NaturalJoin; +import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; import io.trino.sql.tree.Node; import 
io.trino.sql.tree.NullInputCharacteristic; @@ -475,6 +476,22 @@ protected Void visitLateral(Lateral node, Integer indent) return null; } + @Override + protected Void visitNearest(Nearest node, Integer indent) + { + append(indent, "NEAREST ("); + append(indent + 1, "FROM "); + process(node.getRelation(), indent + 1); + node.getWhere().ifPresent(where -> append(indent + 1, "WHERE ") + .append(formatExpression(where)) + .append('\n')); + append(indent + 1, "MATCH ") + .append(formatExpression(node.getMatch())) + .append('\n'); + append(indent, ")"); + return null; + } + @Override protected Void visitTableFunctionInvocation(TableFunctionInvocation node, Integer indent) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index ef61c0f44487..ece14144a148 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -186,6 +186,7 @@ import io.trino.sql.tree.MergeInsert; import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.NaturalJoin; +import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeLocation; @@ -2154,6 +2155,16 @@ public Node visitLateral(SqlBaseParser.LateralContext context) return new Lateral(getLocation(context), (Query) visit(context.query())); } + @Override + public Node visitNearest(SqlBaseParser.NearestContext context) + { + return new Nearest( + getLocation(context), + (Relation) visit(context.relation()), + visitIfPresent(context.where, Expression.class), + (Expression) visit(context.match)); + } + @Override public Node visitTableFunctionInvocation(SqlBaseParser.TableFunctionInvocationContext context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index 
09b8f152d417..a0d0f4acaa70 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -502,6 +502,11 @@ protected R visitLateral(Lateral node, C context) return visitRelation(node, context); } + protected R visitNearest(Nearest node, C context) + { + return visitRelation(node, context); + } + protected R visitValues(Values node, C context) { return visitQueryBody(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java index c396ae0aed4f..74796e31a16f 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java @@ -610,6 +610,15 @@ protected Void visitUnnest(Unnest node, C context) return null; } + @Override + protected Void visitNearest(Nearest node, C context) + { + process(node.getRelation(), context); + node.getWhere().ifPresent(expression -> process(expression, context)); + process(node.getMatch(), context); + return null; + } + @Override protected Void visitGroupBy(GroupBy node, C context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Nearest.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Nearest.java new file mode 100644 index 000000000000..c28df6412491 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Nearest.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class Nearest + extends Relation +{ + private final Relation relation; + private final Optional where; + private final Expression match; + + public Nearest(NodeLocation location, Relation relation, Optional where, Expression match) + { + super(location); + this.relation = requireNonNull(relation, "relation is null"); + this.where = requireNonNull(where, "where is null"); + this.match = requireNonNull(match, "match is null"); + } + + public Relation getRelation() + { + return relation; + } + + public Optional getWhere() + { + return where; + } + + public Expression getMatch() + { + return match; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitNearest(this, context); + } + + @Override + public List getChildren() + { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(relation); + where.ifPresent(children::add); + children.add(match); + return children.build(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("relation", relation) + .add("where", where.orElse(null)) + .add("match", match) + .omitNullValues() + .toString(); + } + + @Override + public int hashCode() + { + return Objects.hash(relation, where, match); + } + + @Override + public boolean 
equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Nearest other = (Nearest) obj; + return Objects.equals(relation, other.relation) && + Objects.equals(where, other.where) && + Objects.equals(match, other.match); + } + + @Override + public boolean shallowEquals(Node other) + { + return sameClass(this, other); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Relation.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Relation.java index 91c2bf8fa096..827e91d0ef4d 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Relation.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Relation.java @@ -18,6 +18,12 @@ public abstract class Relation extends Node { + protected Relation(NodeLocation location) + { + super(location); + } + + @Deprecated protected Relation(Optional location) { super(location); diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 39e2a7ed1e94..1352af1eb7be 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -135,6 +135,7 @@ import io.trino.sql.tree.MergeInsert; import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.NaturalJoin; +import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeLocation; @@ -4982,6 +4983,187 @@ public void testLateral() Optional.of(new JoinOn(BooleanLiteral.TRUE_LITERAL))))); } + @Test + public void testNearest() + { + assertThat(statement( + """ + SELECT * + FROM trades + CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + .isEqualTo(new Query( + location(1, 1), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + 
new QuerySpecification( + location(1, 1), + new Select(location(1, 1), false, ImmutableList.of(new AllColumns(location(1, 8)))), + Optional.of(new Join( + location(2, 6), + Join.Type.CROSS, + new Table(location(2, 6), qualifiedName(location(2, 6), "trades")), + new Nearest( + location(3, 12), + new Table(location(4, 10), qualifiedName(location(4, 10), "quotes")), + Optional.of(new ComparisonExpression( + location(5, 25), + ComparisonExpression.Operator.EQUAL, + new DereferenceExpression( + location(5, 11), + new Identifier(location(5, 11), "quotes", false), + new Identifier(location(5, 18), "symbol", false)), + new DereferenceExpression( + location(5, 27), + new Identifier(location(5, 27), "trades", false), + new Identifier(location(5, 34), "symbol", false)))), + new ComparisonExpression( + location(6, 21), + ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, + new DereferenceExpression( + location(6, 11), + new Identifier(location(6, 11), "quotes", false), + new Identifier(location(6, 18), "ts", false)), + new DereferenceExpression( + location(6, 24), + new Identifier(location(6, 24), "trades", false), + new Identifier(location(6, 31), "ts", false)))), + Optional.empty())), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty())); + + assertThat(statement( + """ + SELECT * + FROM trades, + NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) + """)) + .isEqualTo(new Query( + location(1, 1), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + new QuerySpecification( + location(1, 1), + new Select(location(1, 1), false, ImmutableList.of(new AllColumns(location(1, 8)))), + Optional.of(new Join( + location(1, 1), + Join.Type.IMPLICIT, + new Table(location(2, 6), qualifiedName(location(2, 6), "trades")), + new Nearest( + location(3, 6), + new Table(location(4, 15), 
qualifiedName(location(4, 15), "quotes")), + Optional.of(new ComparisonExpression( + location(5, 30), + ComparisonExpression.Operator.EQUAL, + new DereferenceExpression( + location(5, 16), + new Identifier(location(5, 16), "quotes", false), + new Identifier(location(5, 23), "symbol", false)), + new DereferenceExpression( + location(5, 32), + new Identifier(location(5, 32), "trades", false), + new Identifier(location(5, 39), "symbol", false)))), + new ComparisonExpression( + location(6, 26), + ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, + new DereferenceExpression( + location(6, 16), + new Identifier(location(6, 16), "quotes", false), + new Identifier(location(6, 23), "ts", false)), + new DereferenceExpression( + location(6, 29), + new Identifier(location(6, 29), "trades", false), + new Identifier(location(6, 36), "ts", false)))), + Optional.empty())), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty())); + + assertThat(statement( + """ + SELECT * + FROM trades + LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts + ) ON TRUE + """)) + .isEqualTo(new Query( + location(1, 1), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + new QuerySpecification( + location(1, 1), + new Select(location(1, 1), false, ImmutableList.of(new AllColumns(location(1, 8)))), + Optional.of(new Join( + location(2, 6), + Join.Type.LEFT, + new Table(location(2, 6), qualifiedName(location(2, 6), "trades")), + new Nearest( + location(3, 11), + new Table(location(4, 10), qualifiedName(location(4, 10), "quotes")), + Optional.of(new ComparisonExpression( + location(5, 25), + ComparisonExpression.Operator.EQUAL, + new DereferenceExpression( + location(5, 11), + new Identifier(location(5, 11), "quotes", false), + new Identifier(location(5, 18), "symbol", false)), + new 
DereferenceExpression( + location(5, 27), + new Identifier(location(5, 27), "trades", false), + new Identifier(location(5, 34), "symbol", false)))), + new ComparisonExpression( + location(6, 21), + ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, + new DereferenceExpression( + location(6, 11), + new Identifier(location(6, 11), "quotes", false), + new Identifier(location(6, 18), "ts", false)), + new DereferenceExpression( + location(6, 24), + new Identifier(location(6, 24), "trades", false), + new Identifier(location(6, 31), "ts", false)))), + Optional.of(new JoinOn(new BooleanLiteral(location(7, 6), "TRUE"))))), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty())); + } + @Test public void testStartTransaction() { diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index c9e0f2e6dd2f..0db095a19a3d 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -48,7 +48,7 @@ private static Stream statements() Arguments.of("select * from foo where @what", "line 1:25: mismatched input '@'. Expecting: "), Arguments.of("select * from 'oops", - "line 1:15: mismatched input '''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'TABLE', 'UNNEST', "), + "line 1:15: mismatched input '''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'NEAREST', 'TABLE', 'UNNEST', "), Arguments.of("select *\nfrom x\nfrom", "line 3:1: mismatched input 'from'. 
Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FETCH', 'FOR', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', " + "'LIMIT', 'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', , "), @@ -57,9 +57,9 @@ private static Stream statements() Arguments.of("select ", "line 1:8: mismatched input ''. Expecting: '*', 'ALL', 'DISTINCT', "), Arguments.of("select * from", - "line 1:14: mismatched input ''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'TABLE', 'UNNEST', "), + "line 1:14: mismatched input ''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'NEAREST', 'TABLE', 'UNNEST', "), Arguments.of("select * from ", - "line 1:16: mismatched input ''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'TABLE', 'UNNEST', "), + "line 1:16: mismatched input ''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'NEAREST', 'TABLE', 'UNNEST', "), Arguments.of("select * from `foo`", "line 1:15: backquoted identifiers are not supported; use double quotes to quote identifiers"), Arguments.of("select * from foo `bar`", @@ -113,7 +113,7 @@ private static Stream statements() Arguments.of("CREATE TABLE t (x bigint) COMMENT ", "line 1:35: mismatched input ''. Expecting: "), Arguments.of("SELECT * FROM ( ", - "line 1:17: mismatched input ''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'TABLE', 'UNNEST', , "), + "line 1:17: mismatched input ''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'NEAREST', 'TABLE', 'UNNEST', , "), Arguments.of("SELECT CAST(a AS )", "line 1:18: mismatched input ')'. 
Expecting: "), Arguments.of("SELECT CAST(a AS decimal()", diff --git a/docs/src/main/sphinx/sql/select.md b/docs/src/main/sphinx/sql/select.md index b50e66d29b8f..84affb76b98e 100644 --- a/docs/src/main/sphinx/sql/select.md +++ b/docs/src/main/sphinx/sql/select.md @@ -1469,6 +1469,131 @@ CROSS JOIN LATERAL (SELECT name || ' :-' AS x) CROSS JOIN LATERAL (SELECT x || ')' AS y); ``` +When `LATERAL` appears on the right side of a `FULL JOIN`, the only condition +supported by the current implementation is `ON TRUE`. + +### NEAREST + +`NEAREST` is a relation that selects at most one row from a `FROM` relation for +each row on the left side of a join. + +Use `NEAREST` on the right side of an explicit `CROSS JOIN`, `INNER JOIN` with +`ON TRUE`, `LEFT JOIN` with `ON TRUE`, or an implicit comma join: + +```text +CROSS JOIN NEAREST ( + FROM relation + [ WHERE condition ] + MATCH comparison +) + +INNER JOIN NEAREST ( + FROM relation + [ WHERE condition ] + MATCH comparison +) ON TRUE + +relation, +NEAREST ( + FROM relation + [ WHERE condition ] + MATCH comparison +) + +LEFT JOIN NEAREST ( + FROM relation + [ WHERE condition ] + MATCH comparison + ) ON TRUE +``` + +The `MATCH` clause is required. It must be a single comparison using one +expression from the `FROM` relation and one non-`FROM` expression, with one of +the operators `<`, `<=`, `>`, or `>=`. + +The comparison determines both the matching direction and the ordering of +candidate rows: + +- `<` and `<=` select the closest row from the `FROM` relation whose match key + is smaller than, or smaller than or equal to, the other expression. +- `>` and `>=` select the closest row from the `FROM` relation whose match key + is greater than, or greater than or equal to, the other expression. + +The optional `WHERE` clause filters candidate rows before the nearest row is +selected. 
+ +`NEAREST` can be understood as shorthand for a lateral subquery that filters to +matching candidate rows, orders them by the `FROM` relation match key, and +keeps only the first row. For example: + +- `NEAREST ( + FROM ... + WHERE ... + MATCH right_key <= left_key + )` is equivalent to: + + ```sql + LATERAL ( + SELECT * + FROM ... + WHERE ... + AND right_key <= left_key + ORDER BY right_key DESC + FETCH FIRST 1 ROW ONLY + ) + ``` + +- `NEAREST ( + FROM ... + WHERE ... + MATCH right_key >= left_key + )` is equivalent to: + + ```sql + LATERAL ( + SELECT * + FROM ... + WHERE ... + AND right_key >= left_key + ORDER BY right_key ASC + FETCH FIRST 1 ROW ONLY + ) + ``` + +More generally, `<` and `<=` order the `FROM` relation match key descending, +while `>` and `>=` order it ascending. + +For example, the following query matches each trade with the most recent quote +for the same symbol at or before the trade's timestamp: + +```sql +SELECT trades.symbol, trades.ts, quotes.price +FROM trades +CROSS JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts <= trades.ts +); +``` + +To preserve rows from the left side when no nearest row exists, use +`LEFT JOIN NEAREST`: + +```sql +SELECT trades.symbol, quotes.price +FROM trades +LEFT JOIN NEAREST ( + FROM quotes + WHERE quotes.symbol = trades.symbol + MATCH quotes.ts >= trades.ts +) ON TRUE; +``` + +The current implementation supports `CROSS JOIN NEAREST (...)`, `INNER JOIN +NEAREST (...) ON TRUE`, implicit comma joins with `NEAREST (...)`, and +`LEFT JOIN NEAREST (...) ON TRUE`. +`JOIN USING`, `NATURAL JOIN`, and join conditions other than `ON TRUE` are not +supported for `NEAREST`. ### Qualifying column names When two relations in a join have columns with the same name, the column