diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 04d3cb92b4ed..6ba901446c24 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -271,9 +271,9 @@ rowCount ; queryTerm - : queryPrimary #queryTermDefault - | left=queryTerm operator=INTERSECT setQuantifier? right=queryTerm #setOperation - | left=queryTerm operator=(UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + : queryPrimary #queryTermDefault + | left=queryTerm operator=INTERSECT setQuantifier? corresponding? right=queryTerm #setOperation + | left=queryTerm operator=(UNION | EXCEPT) setQuantifier? corresponding? right=queryTerm #setOperation ; queryPrimary @@ -283,6 +283,10 @@ queryPrimary | '(' queryNoWith ')' #subquery ; +corresponding + : CORRESPONDING (BY columnAliases)? + ; + sortItem : expression ordering=(ASC | DESC)? (NULLS nullOrdering=(FIRST | LAST))? ; @@ -1001,7 +1005,7 @@ nonReserved // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved : ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION | BEGIN | BERNOULLI | BOTH - | CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | COUNT | CURRENT + | CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | CORRESPONDING | COUNT | CURRENT | DATA | DATE | DAY | DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DETERMINISTIC | DISTRIBUTED | DO | DOUBLE | ELSEIF | EMPTY | ENCODING | ERROR | EXCLUDING | EXECUTE | EXPLAIN | FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS @@ -1062,6 +1066,7 @@ CONDITIONAL: 'CONDITIONAL'; CONSTRAINT: 'CONSTRAINT'; COUNT: 'COUNT'; COPARTITION: 'COPARTITION'; +CORRESPONDING: 'CORRESPONDING'; CREATE: 'CREATE'; CROSS: 'CROSS'; CUBE: 'CUBE'; diff --git a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java index 5249453a081e..17a149f46c4d 100644 --- a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java +++ b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java @@ -60,6 +60,7 @@ public void test() "CONDITIONAL", "CONSTRAINT", "COPARTITION", + "CORRESPONDING", "COUNT", "CREATE", "CROSS", diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 47a117bda563..bd6d9adc879d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -224,6 +224,7 @@ public class Analysis private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>(); private final Map columns = new LinkedHashMap<>(); + private final Map, CorrespondingAnalysis> correspondingAnalysis = new LinkedHashMap<>(); private final Map, Double> sampleRatios = new LinkedHashMap<>(); @@ -767,6 +768,16 @@ public ColumnHandle getColumn(Field field) return columns.get(field); } + public CorrespondingAnalysis getCorrespondingAnalysis(Node node) + { + return correspondingAnalysis.get(NodeRef.of(node)); + } + + public void setCorrespondingAnalysis(Node node, CorrespondingAnalysis correspondingAnalysis) + { + this.correspondingAnalysis.put(NodeRef.of(node), correspondingAnalysis); + } + public Optional getAnalyzeMetadata() { return analyzeMetadata; @@ -2547,4 +2558,14 @@ public record JsonTableAnalysis( requireNonNull(orderedOutputColumns, "orderedOutputColumns is null"); } } + + public record CorrespondingAnalysis(List indexes, List fields) + { + public CorrespondingAnalysis + { + indexes = ImmutableList.copyOf(indexes); + fields = ImmutableList.copyOf(fields); + checkArgument(indexes.size() == fields.size(), "indexes and fields must have the same size"); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 33de0fcc09c6..106431aeb3d8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -105,6 +105,7 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; +import io.trino.sql.analyzer.Analysis.CorrespondingAnalysis; import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; import io.trino.sql.analyzer.Analysis.JsonTableAnalysis; import io.trino.sql.analyzer.Analysis.MergeAnalysis; @@ -133,6 +134,7 @@ import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.Comment; import io.trino.sql.tree.Commit; +import io.trino.sql.tree.Corresponding; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateMaterializedView; import io.trino.sql.tree.CreateSchema; @@ -282,6 +284,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -3191,9 +3194,54 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) { checkState(node.getRelations().size() >= 2); - List childrenTypes = node.getRelations().stream() - .map(relation -> process(relation, scope).getRelationType().withOnlyVisibleFields()) - .collect(toImmutableList()); + List relations = node.getRelations(); + + List childrenTypes = new ArrayList<>(); + if (node.getCorresponding().isPresent()) { + checkState(relations.size() == 2, "CORRESPONDING requires 2 relations"); + + Corresponding corresponding = node.getCorresponding().get(); + if (!corresponding.getColumns().isEmpty()) { + throw semanticException(NOT_SUPPORTED, node, "CORRESPONDING with columns is unsupported"); + } + + RelationType leftRelation = process(relations.getFirst(), scope).getRelationType().withOnlyVisibleFields(); + RelationType rightRelation = process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields(); + + Map leftFieldsByName = buildNameToIndex(node, leftRelation); + Map rightFieldsByName = buildNameToIndex(node, rightRelation); + + List correspondingColumns = leftFieldsByName.keySet().stream() + .filter(rightFieldsByName::containsKey) + .collect(toImmutableList()); + + if (correspondingColumns.isEmpty()) { + throw semanticException(MISMATCHED_COLUMN_ALIASES, node, "No corresponding columns"); + } + + ImmutableList.Builder leftColumnIndexes = ImmutableList.builderWithExpectedSize(correspondingColumns.size()); + ImmutableList.Builder rightColumnIndexes = ImmutableList.builderWithExpectedSize(correspondingColumns.size()); + ImmutableList.Builder leftRequiredFields = ImmutableList.builderWithExpectedSize(correspondingColumns.size()); + ImmutableList.Builder rightRequiredFields = ImmutableList.builderWithExpectedSize(correspondingColumns.size()); + for (String correspondingColumn : correspondingColumns) { + int leftFieldIndex = leftFieldsByName.get(correspondingColumn); + int rightFieldIndex = rightFieldsByName.get(correspondingColumn); + leftColumnIndexes.add(leftFieldIndex); + rightColumnIndexes.add(rightFieldIndex); + leftRequiredFields.add(leftRelation.getFieldByIndex(leftFieldIndex)); + rightRequiredFields.add(rightRelation.getFieldByIndex(rightFieldIndex)); + } + + analysis.setCorrespondingAnalysis(node.getRelations().getFirst(), new CorrespondingAnalysis(leftColumnIndexes.build(), leftRequiredFields.build())); + analysis.setCorrespondingAnalysis(node.getRelations().getLast(), new CorrespondingAnalysis(rightColumnIndexes.build(), rightRequiredFields.build())); + + childrenTypes.add(new RelationType(leftRequiredFields.build()).withOnlyVisibleFields()); + childrenTypes.add(new RelationType(rightRequiredFields.build()).withOnlyVisibleFields()); + } + else { + childrenTypes.add(process(relations.getFirst(), scope).getRelationType().withOnlyVisibleFields()); + childrenTypes.add(process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields()); + } String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH); Type[] outputFieldTypes = childrenTypes.get(0).getVisibleFields().stream() @@ -3264,8 +3312,8 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) .collect(toImmutableSet())); } - for (int i = 0; i < node.getRelations().size(); i++) { - Relation relation = node.getRelations().get(i); + for (int i = 0; i < relations.size(); i++) { + Relation relation = relations.get(i); RelationType relationType = childrenTypes.get(i); for (int j = 0; j < relationType.getVisibleFields().size(); j++) { Type outputFieldType = outputFieldTypes[j]; @@ -3279,6 +3327,22 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) return createAndAssignScope(node, scope, outputDescriptorFields); } + private static Map buildNameToIndex(SetOperation node, RelationType relationType) + { + Map nameToIndex = new LinkedHashMap<>(); + for (int i = 0; i < relationType.getAllFieldCount(); i++) { + Field field = relationType.getFieldByIndex(i); + String name = field.getName() + .orElseThrow(() -> semanticException(MISSING_COLUMN_NAME, node, "Anonymous columns are not allowed in set operations with CORRESPONDING")) + // TODO https://github.com/trinodb/trino/issues/17 Add support for case sensitive identifiers + .toLowerCase(ENGLISH); + if (nameToIndex.put(name, i) != null) { + throw semanticException(AMBIGUOUS_NAME, node, "Duplicate columns found when using CORRESPONDING in set operations: %s", name); + } + } + return ImmutableMap.copyOf(nameToIndex); + } + @Override protected Scope visitJoin(Join node, Optional scope) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 2a6610c800d3..3e0ccfebc172 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -40,6 +40,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.Analysis; +import io.trino.sql.analyzer.Analysis.CorrespondingAnalysis; import io.trino.sql.analyzer.Analysis.JsonTableAnalysis; import io.trino.sql.analyzer.Analysis.TableArgumentAnalysis; import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; @@ -1865,9 +1866,33 @@ private SetOperationPlan process(SetOperation node) ImmutableListMultimap.Builder symbolMapping = ImmutableListMultimap.builder(); ImmutableList.Builder sources = ImmutableList.builder(); - for (Relation child : node.getRelations()) { + List relations = node.getRelations(); + checkArgument(relations.size() == 2, "relations size must be 2"); + for (Relation child : relations) { RelationPlan plan = process(child, null); + if (node.getCorresponding().isPresent()) { + int[] fieldIndexForVisibleColumn = new int[plan.getDescriptor().getVisibleFieldCount()]; + int visibleColumn = 0; + for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { + if (!plan.getDescriptor().getFieldByIndex(i).isHidden()) { + fieldIndexForVisibleColumn[visibleColumn] = i; + visibleColumn++; + } + } + + CorrespondingAnalysis correspondingAnalysis = analysis.getCorrespondingAnalysis(child); + List requiredColumns = correspondingAnalysis.indexes().stream() + .filter(column -> column < fieldIndexForVisibleColumn.length) + .map(column -> fieldIndexForVisibleColumn[column]) + .map(plan::getSymbol) + .collect(toImmutableList()); + + ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), plan.getRoot(), Assignments.identity(requiredColumns)); + Scope scope = Scope.builder().withRelationType(plan.getScope().getRelationId(), new RelationType(correspondingAnalysis.fields())).build(); + plan = new RelationPlan(projectNode, scope, requiredColumns, plan.getOuterContext()); + } + NodeAndMappings planAndMappings; List types = analysis.getRelationCoercion(child); if (types == null) { diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestSetOperations.java b/core/trino-main/src/test/java/io/trino/sql/query/TestSetOperations.java index 96df390fbb79..8d271ee6c017 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestSetOperations.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestSetOperations.java @@ -293,4 +293,206 @@ public void testIntersectWithEmptyBranches() .describedAs("INTERSECT DISTINCT with empty branches") .returnsEmptyResult(); } + + @Test + void testExceptCorresponding() + { + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y) + EXCEPT CORRESPONDING + SELECT * FROM (VALUES ('alice', 1)) t(y, x) + """)) + .returnsEmptyResult(); + + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y) + EXCEPT ALL CORRESPONDING + SELECT * FROM (VALUES ('alice', 1)) t(y, x) + """)) + .matches("VALUES (1, 'alice')"); + + // Test EXCEPT with different number of columns + assertThat(assertions.query( + """ + SELECT * FROM (VALUES 1) t(x) + EXCEPT CORRESPONDING + SELECT * FROM (VALUES ('alice', 1)) t(y, x) + """)) + .returnsEmptyResult(); + + assertThat(assertions.query( + """ + SELECT * FROM (VALUES ('alice', 1)) t(y, x) + EXCEPT CORRESPONDING + SELECT * FROM (VALUES 1) t(x) + """)) + .returnsEmptyResult(); + + // Test case insensitivity + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice')) t(X, Y) + EXCEPT CORRESPONDING + SELECT * FROM (VALUES ('alice', 1)) t(y, x) + """)) + .returnsEmptyResult(); + } + + @Test + void testUnionCorresponding() + { + assertThat(assertions.query( + """ + SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x) + UNION CORRESPONDING + SELECT 1 AS x, 'alice' AS y + """)) + .matches("VALUES ('alice', 1), ('bob', 2)"); + + assertThat(assertions.query( + """ + SELECT 1 AS x, 'alice' AS y + UNION ALL CORRESPONDING + SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x) + """)) + .matches("VALUES (1, 'alice'), (1, 'alice'), (2, 'bob')"); + + // Test UNION with different number of columns + assertThat(assertions.query( + """ + SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x) + UNION ALL CORRESPONDING + SELECT 3 AS x + """)) + .matches("VALUES 1, 2, 3"); + + assertThat(assertions.query( + """ + SELECT 3 AS x + UNION ALL CORRESPONDING + SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x) + """)) + .matches("VALUES 1, 2, 3"); + + // Test case insensitivity + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice')) t(X, Y) + UNION ALL CORRESPONDING + SELECT * FROM (VALUES ('bob', 2)) t(y, x) + """)) + .matches("VALUES (1, 'alice'), (2, 'bob')"); + } + + @Test + void testIntersectCorresponding() + { + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y) + INTERSECT CORRESPONDING + SELECT * FROM (VALUES ('alice', 1), ('alice', 1)) t(y, x) + """)) + .matches("VALUES (1, 'alice')"); + + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y) + INTERSECT ALL CORRESPONDING + SELECT * FROM (VALUES ('alice', 1), ('alice', 1)) t(y, x) + """)) + .matches("VALUES (1, 'alice'), (1, 'alice')"); + + // Test INTERSECT with different number of columns + assertThat(assertions.query( + """ + SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x) + INTERSECT ALL CORRESPONDING + SELECT * FROM (VALUES 1) t(x) + """)) + .matches("VALUES 1"); + + assertThat(assertions.query( + """ + SELECT * FROM (VALUES 1) t(x) + INTERSECT ALL CORRESPONDING + SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x) + """)) + .matches("VALUES 1"); + + // Test case insensitivity + assertThat(assertions.query( + """ + SELECT * FROM (VALUES (1, 'alice'), (2, 'bob')) t(X, Y) + INTERSECT ALL CORRESPONDING + SELECT * FROM (VALUES ('alice', 1), ('carol', 3)) t(y, x) + """)) + .matches("VALUES (1, 'alice')"); + } + + @Test + void testCorrespondingDuplicateNames() + { + assertThat(assertions.query("SELECT 1 AS x, 2 AS y EXCEPT CORRESPONDING SELECT 1 AS x, 2 AS X")) + .failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x"); + assertThat(assertions.query("SELECT 1 AS x, 2 AS X EXCEPT CORRESPONDING SELECT 1 AS y, 2 AS x")) + .failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x"); + + assertThat(assertions.query("SELECT 1 AS x, 2 AS y UNION CORRESPONDING SELECT 1 AS x, 2 AS X")) + .failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x"); + assertThat(assertions.query("SELECT 1 AS x, 2 AS X UNION CORRESPONDING SELECT 1 AS x, 2 AS y")) + .failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x"); + + assertThat(assertions.query("SELECT 1 AS x, 2 AS y INTERSECT CORRESPONDING SELECT 1 AS x, 2 AS X")) + .failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x"); + assertThat(assertions.query("SELECT 1 AS X, 2 AS x INTERSECT CORRESPONDING SELECT 1 AS x, 2 AS y")) + .failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x"); + } + + @Test + void testCorrespondingUnsupportedColumnNames() + { + assertThat(assertions.query("SELECT 1 AS x EXCEPT CORRESPONDING BY (x) SELECT 2 AS x")) + .failure().hasMessage("line 1:15: CORRESPONDING with columns is unsupported"); + + assertThat(assertions.query("SELECT 1 AS x UNION CORRESPONDING BY (x) SELECT 2 AS x")) + .failure().hasMessage("line 1:15: CORRESPONDING with columns is unsupported"); + + assertThat(assertions.query("SELECT 1 AS x INTERSECT CORRESPONDING BY (x) SELECT 2 AS x")) + .failure().hasMessage("line 1:15: CORRESPONDING with columns is unsupported"); + } + + @Test + void testCorrespondingNameMismatch() + { + assertThat(assertions.query("SELECT 1 AS x EXCEPT CORRESPONDING SELECT 2 AS y")) + .failure().hasMessage("line 1:15: No corresponding columns"); + + assertThat(assertions.query("SELECT 1 AS x UNION CORRESPONDING SELECT 2 AS y")) + .failure().hasMessage("line 1:15: No corresponding columns"); + + assertThat(assertions.query("SELECT 1 AS x INTERSECT CORRESPONDING SELECT 2 AS y")) + .failure().hasMessage("line 1:15: No corresponding columns"); + } + + @Test + void testCorrespondingWithAnonymousColumn() + { + assertThat(assertions.query("SELECT 1 EXCEPT CORRESPONDING SELECT 2 AS x")) + .failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING"); + assertThat(assertions.query("SELECT 1 AS x EXCEPT CORRESPONDING SELECT 2")) + .failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING"); + + assertThat(assertions.query("SELECT 1 UNION CORRESPONDING SELECT 2 AS x")) + .failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING"); + assertThat(assertions.query("SELECT 1 AS x UNION CORRESPONDING SELECT 2")) + .failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING"); + + assertThat(assertions.query("SELECT 1 INTERSECT CORRESPONDING SELECT 2 AS x")) + .failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING"); + assertThat(assertions.query("SELECT 1 AS x INTERSECT CORRESPONDING SELECT 2")) + .failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING"); + } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 4b860d5eb9c2..a21b08654fae 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -32,6 +32,7 @@ import io.trino.sql.tree.Commit; import io.trino.sql.tree.CompoundStatement; import io.trino.sql.tree.ControlStatement; +import io.trino.sql.tree.Corresponding; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateFunction; import io.trino.sql.tree.CreateMaterializedView; @@ -1034,6 +1035,7 @@ protected Void visitUnion(Union node, Integer indent) if (!node.isDistinct()) { builder.append("ALL "); } + appendCorresponding(node.getCorresponding()); } } @@ -1049,6 +1051,7 @@ protected Void visitExcept(Except node, Integer indent) if (!node.isDistinct()) { builder.append("ALL "); } + appendCorresponding(node.getCorresponding()); processRelation(node.getRight(), indent); @@ -1068,12 +1071,26 @@ protected Void visitIntersect(Intersect node, Integer indent) if (!node.isDistinct()) { builder.append("ALL "); } + appendCorresponding(node.getCorresponding()); } } return null; } + private void appendCorresponding(Optional node) + { + node.ifPresent(corresponding -> { + builder.append("CORRESPONDING "); + if (!corresponding.getColumns().isEmpty()) { + builder.append("BY "); + builder.append(corresponding.getColumns().stream() + .map(SqlFormatter::formatName) + .collect(joining(", ", "(", ") "))); + } + }); + } + @Override protected Void visitMerge(Merge node, Integer indent) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 06322c2c9234..d798b2ac5820 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -19,6 +19,7 @@ import io.trino.grammar.sql.SqlBaseBaseVisitor; import io.trino.grammar.sql.SqlBaseLexer; import io.trino.grammar.sql.SqlBaseParser; +import io.trino.grammar.sql.SqlBaseParser.CorrespondingContext; import io.trino.grammar.sql.SqlBaseParser.CreateCatalogContext; import io.trino.grammar.sql.SqlBaseParser.DropCatalogContext; import io.trino.sql.tree.AddColumn; @@ -49,6 +50,7 @@ import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.CompoundStatement; import io.trino.sql.tree.ControlStatement; +import io.trino.sql.tree.Corresponding; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateFunction; import io.trino.sql.tree.CreateMaterializedView; @@ -1384,10 +1386,17 @@ public Node visitSetOperation(SqlBaseParser.SetOperationContext context) boolean distinct = context.setQuantifier() == null || context.setQuantifier().DISTINCT() != null; + CorrespondingContext correspondingContext = context.corresponding(); + Optional corresponding = Optional.empty(); + if (correspondingContext != null) { + List columns = correspondingContext.columnAliases() == null ? List.of() : visit(correspondingContext.columnAliases().identifier(), Identifier.class); + corresponding = Optional.of(new Corresponding(getLocation(correspondingContext), columns)); + } + return switch (context.operator.getType()) { - case SqlBaseLexer.UNION -> new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct); - case SqlBaseLexer.INTERSECT -> new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct); - case SqlBaseLexer.EXCEPT -> new Except(getLocation(context.EXCEPT()), left, right, distinct); + case SqlBaseLexer.UNION -> new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct, corresponding); + case SqlBaseLexer.INTERSECT -> new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct, corresponding); + case SqlBaseLexer.EXCEPT -> new Except(getLocation(context.EXCEPT()), left, right, distinct, corresponding); default -> throw new IllegalArgumentException("Unsupported set operation: " + context.operator.getText()); }; } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Corresponding.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Corresponding.java new file mode 100644 index 000000000000..89320b02a757 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Corresponding.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class Corresponding + extends Node +{ + private final List columns; + + public Corresponding(NodeLocation location, List columns) + { + super(location); + this.columns = ImmutableList.copyOf(columns); + } + + public List getColumns() + { + return columns; + } + + @Override + protected R accept(AstVisitor visitor, C context) + { + return super.accept(visitor, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public int hashCode() + { + return Objects.hash(columns); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + Corresponding o = (Corresponding) obj; + return Objects.equals(columns, o.columns); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columns", columns) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Except.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Except.java index 7bc839761a86..5a6ca216e6e1 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Except.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Except.java @@ -25,12 +25,12 @@ public class Except extends SetOperation { - private final Relation left; - private final Relation right; + private final QueryBody left; + private final QueryBody right; - public Except(NodeLocation location, Relation left, Relation right, boolean distinct) + public Except(NodeLocation location, QueryBody left, QueryBody right, boolean distinct, Optional corresponding) { - super(Optional.of(location), distinct); + super(Optional.of(location), distinct, corresponding); requireNonNull(left, "left is null"); requireNonNull(right, "right is null"); @@ -73,6 +73,7 @@ public String toString() .add("left", left) .add("right", right) .add("distinct", isDistinct()) + .add("corresponding", getCorresponding()) .toString(); } @@ -88,13 +89,14 @@ public boolean equals(Object obj) Except o = (Except) obj; return Objects.equals(left, o.left) && Objects.equals(right, o.right) && - isDistinct() == o.isDistinct(); + isDistinct() == o.isDistinct() && + Objects.equals(getCorresponding(), o.getCorresponding()); } @Override public int hashCode() { - return Objects.hash(left, right, isDistinct()); + return Objects.hash(left, right, isDistinct(), getCorresponding()); } @Override @@ -104,6 +106,8 @@ public boolean shallowEquals(Node other) return false; } - return this.isDistinct() == ((Except) other).isDistinct(); + Except otherExcept = (Except) other; + return this.isDistinct() == otherExcept.isDistinct() && + Objects.equals(getCorresponding(), otherExcept.getCorresponding()); } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Intersect.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Intersect.java index ba3c446e4482..982190501ed3 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Intersect.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Intersect.java @@ -20,6 +20,7 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class Intersect @@ -27,20 +28,21 @@ public class Intersect { private final List relations; - public Intersect(List relations, boolean distinct) + public Intersect(List relations, boolean distinct, Optional corresponding) { - this(Optional.empty(), relations, distinct); + this(Optional.empty(), relations, distinct, corresponding); } - public Intersect(NodeLocation location, List relations, boolean distinct) + public Intersect(NodeLocation location, List relations, boolean distinct, Optional corresponding) { - this(Optional.of(location), relations, distinct); + this(Optional.of(location), relations, distinct, corresponding); } - private Intersect(Optional location, List relations, boolean distinct) + private Intersect(Optional location, List relations, boolean distinct, Optional corresponding) { - super(location, distinct); + super(location, distinct, corresponding); requireNonNull(relations, "relations is null"); + checkArgument(relations.size() == 2, "relations must have 2 elements"); this.relations = ImmutableList.copyOf(relations); } @@ -69,6 +71,7 @@ public String toString() return toStringHelper(this) .add("relations", relations) .add("distinct", isDistinct()) + .add("corresponding", getCorresponding()) .toString(); } @@ -83,13 +86,14 @@ public boolean equals(Object obj) } Intersect o = (Intersect) obj; return Objects.equals(relations, o.relations) && - isDistinct() == o.isDistinct(); + isDistinct() == o.isDistinct() && + Objects.equals(getCorresponding(), o.getCorresponding()); } @Override public int hashCode() { - return Objects.hash(relations, isDistinct()); + return Objects.hash(relations, isDistinct(), getCorresponding()); } @Override @@ -99,6 +103,8 @@ public boolean shallowEquals(Node other) return false; } - return this.isDistinct() == ((Intersect) other).isDistinct(); + Intersect otherIntersect = (Intersect) other; + return this.isDistinct() == otherIntersect.isDistinct() && + Objects.equals(getCorresponding(), otherIntersect.getCorresponding()); } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/SetOperation.java b/core/trino-parser/src/main/java/io/trino/sql/tree/SetOperation.java index ccf2382dc307..749bf2be2fa9 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/SetOperation.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/SetOperation.java @@ -16,15 +16,19 @@ import java.util.List; import java.util.Optional; +import static java.util.Objects.requireNonNull; + public abstract class SetOperation extends QueryBody { private final boolean distinct; + private final Optional corresponding; - protected SetOperation(Optional location, boolean distinct) + protected SetOperation(Optional location, boolean distinct, Optional corresponding) { super(location); this.distinct = distinct; + this.corresponding = requireNonNull(corresponding, "corresponding is null"); } public boolean isDistinct() @@ -32,6 +36,11 @@ public boolean isDistinct() return distinct; } + public Optional getCorresponding() + { + return corresponding; + } + @Override public R accept(AstVisitor visitor, C context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Union.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Union.java index 04aaeee9c07a..6ccf69317550 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Union.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Union.java @@ -20,6 +20,7 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class Union @@ -27,20 +28,21 @@ public class Union { private final List relations; - public Union(List relations, boolean distinct) + public Union(List relations, boolean distinct, Optional corresponding) { - this(Optional.empty(), relations, distinct); + this(Optional.empty(), relations, distinct, corresponding); } - public Union(NodeLocation location, List relations, boolean distinct) + public Union(NodeLocation location, List relations, boolean distinct, Optional corresponding) { - this(Optional.of(location), relations, distinct); + this(Optional.of(location), relations, distinct, corresponding); } - private Union(Optional location, List relations, boolean distinct) + private Union(Optional location, List relations, boolean distinct, Optional corresponding) { - super(location, distinct); + super(location, distinct, corresponding); requireNonNull(relations, "relations is null"); + checkArgument(relations.size() == 2, "relations must have 2 elements"); this.relations = ImmutableList.copyOf(relations); } @@ -60,7 +62,10 @@ public R accept(AstVisitor visitor, C context) @Override public List getChildren() { - return relations; + ImmutableList.Builder builder = ImmutableList.builder(); + builder.addAll(relations); + getCorresponding().ifPresent(builder::add); + return builder.build(); } @Override @@ -69,6 +74,7 @@ public String toString() return toStringHelper(this) .add("relations", relations) .add("distinct", isDistinct()) + .add("corresponding", getCorresponding()) .toString(); } @@ -83,13 +89,14 @@ public boolean equals(Object obj) } Union o = (Union) obj; return Objects.equals(relations, o.relations) && - isDistinct() == o.isDistinct(); + isDistinct() == o.isDistinct() && + Objects.equals(getCorresponding(), o.getCorresponding()); } @Override public int hashCode() { - return Objects.hash(relations, isDistinct()); + return Objects.hash(relations, isDistinct(), getCorresponding()); } @Override @@ -99,6 +106,8 @@ public boolean shallowEquals(Node other) return false; } - return this.isDistinct() == ((Union) other).isDistinct(); + Union otherUnion = (Union) other; + return this.isDistinct() == otherUnion.isDistinct() && + Objects.equals(getCorresponding(), otherUnion.getCorresponding()); } } diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index b3c35ef28ab0..8a63fe9de518 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -38,6 +38,7 @@ import io.trino.sql.tree.Comment; import io.trino.sql.tree.Commit; import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.Corresponding; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateMaterializedView; import io.trino.sql.tree.CreateRole; @@ -66,6 +67,7 @@ import io.trino.sql.tree.DropView; import io.trino.sql.tree.EmptyPattern; import io.trino.sql.tree.EmptyTableTreatment; +import io.trino.sql.tree.Except; import io.trino.sql.tree.Execute; import io.trino.sql.tree.ExecuteImmediate; import io.trino.sql.tree.ExistsPredicate; @@ -850,9 +852,26 @@ public void testIntersect() assertStatement("SELECT 123 INTERSECT DISTINCT SELECT 123 INTERSECT ALL SELECT 123", query(new Intersect( ImmutableList.of( - new Intersect(ImmutableList.of(createSelect123(), createSelect123()), true), + new Intersect(ImmutableList.of(createSelect123(), createSelect123()), true, Optional.empty()), createSelect123()), - false))); + false, + Optional.empty()))); + + assertStatement("SELECT 123 INTERSECT DISTINCT CORRESPONDING SELECT 123 INTERSECT ALL CORRESPONDING SELECT 123", + query(new Intersect( + ImmutableList.of( + new Intersect(ImmutableList.of(createSelect123(), createSelect123()), true, Optional.of(new Corresponding(location(1 ,1), List.of()))), + createSelect123()), + false, + Optional.of(new Corresponding(location(1 ,1), List.of()))))); + + assertStatement("SELECT 123 INTERSECT DISTINCT CORRESPONDING BY (x) SELECT 123 INTERSECT ALL CORRESPONDING SELECT 123", + query(new Intersect( + ImmutableList.of( + new Intersect(ImmutableList.of(createSelect123(), createSelect123()), true, Optional.of(new Corresponding(location(1 ,1), List.of(identifier("x"))))), + createSelect123()), + false, + Optional.of(new Corresponding(location(1 ,1), List.of()))))); } @Test @@ -861,9 +880,54 @@ public void testUnion() assertStatement("SELECT 123 UNION DISTINCT SELECT 123 UNION ALL SELECT 123", query(new Union( ImmutableList.of( - new Union(ImmutableList.of(createSelect123(), createSelect123()), true), + new Union(ImmutableList.of(createSelect123(), createSelect123()), true, Optional.empty()), createSelect123()), - false))); + false, + Optional.empty()))); + + assertStatement("SELECT 123 UNION DISTINCT CORRESPONDING SELECT 123 UNION ALL CORRESPONDING SELECT 123", + query(new Union( + ImmutableList.of( + new Union(ImmutableList.of(createSelect123(), createSelect123()), true, Optional.of(new Corresponding(location(1 ,1), List.of()))), + createSelect123()), + false, + Optional.of(new Corresponding(location(1 ,1), List.of()))))); + + assertStatement("SELECT 123 UNION DISTINCT CORRESPONDING BY (x) SELECT 123 UNION ALL CORRESPONDING SELECT 123", + query(new Union( + ImmutableList.of( + new Union(ImmutableList.of(createSelect123(), createSelect123()), true, Optional.of(new Corresponding(location(1 ,1), List.of(identifier("x"))))), + createSelect123()), + false, + Optional.of(new Corresponding(location(1 ,1), List.of()))))); + } + + @Test + public void testExcept() + { + assertStatement("SELECT 123 EXCEPT DISTINCT SELECT 123 EXCEPT ALL SELECT 123", + query(new Except( + location(1, 1), + new Except(location(1, 1), createSelect123(), createSelect123(), true, Optional.empty()), + createSelect123(), + false, + Optional.empty()))); + + assertStatement("SELECT 123 EXCEPT DISTINCT CORRESPONDING SELECT 123 EXCEPT ALL CORRESPONDING SELECT 123", + query(new Except( + location(1, 1), + new Except(location(1, 1), createSelect123(), createSelect123(), true, Optional.of(new Corresponding(location(1 ,1), List.of()))), + createSelect123(), + false, + Optional.of(new Corresponding(location(1 ,1), List.of()))))); + + assertStatement("SELECT 123 EXCEPT DISTINCT CORRESPONDING BY (x) SELECT 123 EXCEPT ALL CORRESPONDING SELECT 123", + query(new Except( + location(1, 1), + new Except(location(1, 1), createSelect123(), createSelect123(), true, Optional.of(new Corresponding(location(1 ,1), List.of(identifier("x"))))), + createSelect123(), + false, + Optional.of(new Corresponding(location(1 ,1), List.of()))))); } private static QuerySpecification createSelect123() diff --git a/docs/src/main/sphinx/sql/select.md b/docs/src/main/sphinx/sql/select.md index cd32437fc4f1..668e811baa52 100644 --- a/docs/src/main/sphinx/sql/select.md +++ b/docs/src/main/sphinx/sql/select.md @@ -773,15 +773,15 @@ specifications contains the component, the default value is used. to combine the results of more than one select statement into a single result set: ```text -query UNION [ALL | DISTINCT] query +query UNION [ALL | DISTINCT] [CORRESPONDING] query ``` ```text -query INTERSECT [ALL | DISTINCT] query +query INTERSECT [ALL | DISTINCT] [CORRESPONDING] query ``` ```text -query EXCEPT [ALL | DISTINCT] query +query EXCEPT [ALL | DISTINCT] [CORRESPONDING] query ``` The argument `ALL` or `DISTINCT` controls which rows are included in @@ -850,6 +850,36 @@ SELECT * FROM (VALUES 42, 13); (2 rows) ``` +`CORRESPONDING` matches columns by name instead of by position: + +```sql +SELECT * FROM (VALUES (1, 'alice')) AS t(id, name) +UNION ALL CORRESPONDING +SELECT * FROM (VALUES ('bob', 2)) AS t(name, id); +``` + +```text + id | name +----+------- + 1 | alice + 2 | bob +(2 rows) +``` + +```sql +SELECT * FROM (VALUES (DATE '2025-04-23', 'alice')) AS t(order_date, name) +UNION ALL CORRESPONDING +SELECT * FROM (VALUES ('bob', 123.45)) AS t(name, price); +``` + +```text + name +------- + alice + bob +(2 rows) +``` + ### INTERSECT clause `INTERSECT` returns only the rows that are in the result sets of both the first and @@ -871,6 +901,21 @@ SELECT 13; (2 rows) ``` +`CORRESPONDING` matches columns by name instead of by position: + +```sql +SELECT * FROM (VALUES (1, 'alice')) AS t(id, name) +INTERSECT CORRESPONDING +SELECT * FROM (VALUES ('alice', 1)) AS t(name, id); +``` + +```text + id | name +----+------- + 1 | alice +(1 row) +``` + ### EXCEPT clause `EXCEPT` returns the rows that are in the result set of the first query, @@ -892,6 +937,21 @@ SELECT 13; (2 rows) ``` +`CORRESPONDING` matches columns by name instead of by position: + +```sql +SELECT * FROM (VALUES (1, 'alice'), (2, 'bob')) AS t(id, name) +EXCEPT CORRESPONDING +SELECT * FROM (VALUES ('alice', 1)) AS t(name, id); +``` + +```text + id | name +----+------ + 2 | bob +(1 row) +``` + (order-by-clause)= ## ORDER BY clause