diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 33335bae4c96..e5947e2006bd 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -385,11 +385,36 @@ joinCriteria ; sampledRelation - : patternRecognition ( + : pivot ( TABLESAMPLE sampleType '(' percentage=expression ')' )? ; +pivot + : patternRecognition ( + PIVOT '(' + pivotAggregation (',' pivotAggregation)* + FOR pivotColumns IN '(' pivotValueGroup (',' pivotValueGroup)* ')' + (GROUP BY groupBy)? + ')' + (AS? identifier columnAliases?)? + )? + ; + +pivotAggregation + : expression (AS? identifier)? + ; + +pivotColumns + : qualifiedName + | '(' qualifiedName (',' qualifiedName)* ')' + ; + +pivotValueGroup + : '(' expression (',' expression)+ ')' (AS? identifier)? + | expression (AS? identifier)? + ; + sampleType : BERNOULLI | SYSTEM @@ -1059,7 +1084,7 @@ nonReserved | MAP | MATCH | MATCHED | MATCHES | MATCH_RECOGNIZE | MATERIALIZED | MEASURES | MERGE | MINUTE | MONTH | NEAREST | NESTED | NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS | OBJECT | OF | OFFSET | OMIT | ONE | ONLY | OPTION | ORDINALITY | OUTPUT | OVER | OVERFLOW - | PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERIOD | PERMUTE | PLAN | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE + | PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERIOD | PERMUTE | PIVOT | PLAN | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE | QUOTES | RANGE | READ | REFRESH | RENAME | REPEAT | REPEATABLE | REPLACE | RESET | RESPECT | RESTRICT | RETURN | RETURNING | RETURNS | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS | RUNNING | SCALAR | SCHEMA | SCHEMAS | SECOND | SECURITY | SEEK | SERIALIZABLE | SESSION | SET | SETS @@ -1273,6 +1298,7 @@ PATTERN: 'PATTERN'; PER: 'PER'; PERIOD: 'PERIOD'; PERMUTE: 'PERMUTE'; +PIVOT: 'PIVOT'; PLAN : 'PLAN'; POSITION: 'POSITION'; PRECEDING: 'PRECEDING'; diff --git a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java index a58de81dc7c9..7ac03d863d72 100644 --- a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java +++ b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java @@ -227,6 +227,7 @@ public void test() "PER", "PERIOD", "PERMUTE", + "PIVOT", "PLAN", "POSITION", "PRECEDING", diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index e40a05920abb..3707824d25a2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -80,6 +80,7 @@ import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.Parameter; +import io.trino.sql.tree.Pivot; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.QuantifiedComparisonExpression; import io.trino.sql.tree.Query; @@ -146,6 +147,8 @@ public class Analysis private final Map, Query> namedQueries = new LinkedHashMap<>(); + private final Map, PivotAnalysis> pivotAnalyses = new LinkedHashMap<>(); + // map expandable query to the node being the inner recursive reference private final Map, Node> expandableNamedQueries = new LinkedHashMap<>(); @@ -188,7 +191,7 @@ public class Analysis private final Map, List> aggregates = new LinkedHashMap<>(); private final Map, List> orderByAggregates = new LinkedHashMap<>(); - private final Map, GroupingSetAnalysis> groupingSets = new LinkedHashMap<>(); + private final Map, GroupingSetAnalysis> groupingSets = new LinkedHashMap<>(); private final Map, Expression> where = new LinkedHashMap<>(); private final Map, Expression> having = new LinkedHashMap<>(); @@ -438,7 +441,7 @@ public Map, LambdaArgumentDeclaration> getLambdaArgumentRefe return unmodifiableMap(lambdaArgumentReferences); } - public void setGroupingSets(QuerySpecification node, GroupingSetAnalysis groupingSets) + public void setGroupingSets(Node node, GroupingSetAnalysis groupingSets) { this.groupingSets.put(NodeRef.of(node), groupingSets); } @@ -448,7 +451,7 @@ public boolean isAggregation(QuerySpecification node) return groupingSets.containsKey(NodeRef.of(node)); } - public GroupingSetAnalysis getGroupingSets(QuerySpecification node) + public GroupingSetAnalysis getGroupingSets(Node node) { return groupingSets.get(NodeRef.of(node)); } @@ -895,6 +898,21 @@ public void registerNamedQuery(Table tableReference, Query query) namedQueries.put(NodeRef.of(tableReference), query); } + public void registerPivotAnalysis(Pivot pivot, PivotAnalysis analysis) + { + requireNonNull(pivot, "pivot is null"); + requireNonNull(analysis, "analysis is null"); + + pivotAnalyses.put(NodeRef.of(pivot), analysis); + } + + public PivotAnalysis getPivotAnalysis(Pivot pivot) + { + PivotAnalysis analysis = pivotAnalyses.get(NodeRef.of(pivot)); + checkArgument(analysis != null, "pivot has no analysis registered: %s", pivot); + return analysis; + } + public void registerExpandableQuery(Query query, Node recursiveReference) { requireNonNull(query, "query is null"); @@ -1743,6 +1761,29 @@ public Set getAllFields() } } + public record PivotAnalysis( + GroupingSetAnalysis groupingSetAnalysis, + boolean distinctGroupingSets, + List outputColumns, + List aggregates) + { + public PivotAnalysis + { + requireNonNull(groupingSetAnalysis, "groupingSetAnalysis is null"); + outputColumns = ImmutableList.copyOf(outputColumns); + aggregates = ImmutableList.copyOf(aggregates); + } + } + + public record PivotOutputColumn(String name, Type type) + { + public PivotOutputColumn + { + requireNonNull(name, "name is null"); + requireNonNull(type, "type is null"); + } + } + public static class UnnestAnalysis { private final Map, List> mappings; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 7fb84dcef8e0..89986538119d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -104,6 +104,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeNotFoundException; import io.trino.spi.type.VarcharType; +import io.trino.sql.ExpressionFormatter; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.Analysis.CorrespondingAnalysis; @@ -111,6 +112,8 @@ import io.trino.sql.analyzer.Analysis.JsonTableAnalysis; import io.trino.sql.analyzer.Analysis.MergeAnalysis; import io.trino.sql.analyzer.Analysis.NearestAnalysis; +import io.trino.sql.analyzer.Analysis.PivotAnalysis; +import io.trino.sql.analyzer.Analysis.PivotOutputColumn; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.Analysis.SourceColumn; @@ -210,6 +213,9 @@ import io.trino.sql.tree.OrdinalityColumn; import io.trino.sql.tree.Parameter; import io.trino.sql.tree.PatternRecognitionRelation; +import io.trino.sql.tree.Pivot; +import io.trino.sql.tree.PivotAggregation; +import io.trino.sql.tree.PivotValueGroup; import io.trino.sql.tree.PlanLeaf; import io.trino.sql.tree.PlanParentChild; import io.trino.sql.tree.PlanSiblings; @@ -3226,6 +3232,158 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional s return createAndAssignScope(relation, scope, relationScope.getRelationType()); } + @Override + protected Scope visitPivot(Pivot relation, Optional scope) + { + validatePivotShape(relation); + + // Analyze the input relation in the surrounding scope. + Scope inputScope = process(relation.getInput(), scope); + + // Resolve pivot columns against the input scope. + for (Expression pivotColumn : relation.getPivotColumns()) { + analysis.recordSubqueries(relation, analyzeExpression(pivotColumn, inputScope)); + } + + // Resolve and validate IN values: each value must be coercible to the corresponding + // pivot column's type. The coercion itself is applied at planning time. + for (PivotValueGroup valueGroup : relation.getValueGroups()) { + for (int i = 0; i < relation.getPivotColumns().size(); i++) { + Expression value = valueGroup.getValues().get(i); + analysis.recordSubqueries(relation, analyzeExpression(value, inputScope)); + Type pivotColumnType = analysis.getType(relation.getPivotColumns().get(i)); + Type valueType = analysis.getType(value); + if (!typeCoercion.canCoerce(valueType, pivotColumnType)) { + throw semanticException( + TYPE_MISMATCH, + value, + "Pivot value of type %s cannot be coerced to pivot column type %s", + valueType, + pivotColumnType); + } + } + } + + // Resolve GROUP BY expressions against the input scope. PIVOT does not have a SELECT + // list, so ordinal references and AUTO grouping are not meaningful here. + GroupingSetAnalysis groupingSetAnalysis; + boolean distinctGroupingSets; + if (relation.getGroupBy().isPresent()) { + groupingSetAnalysis = analyzeGroupingElements(relation, relation.getGroupBy().get(), inputScope, ImmutableList.of()); + distinctGroupingSets = relation.getGroupBy().get().isDistinct(); + } + else { + groupingSetAnalysis = new GroupingSetAnalysis(ImmutableList.of(), ImmutableList.of(), ImmutableList.of(), ImmutableList.of(), ImmutableList.of()); + analysis.setGroupingSets(relation, groupingSetAnalysis); + distinctGroupingSets = false; + } + List groupingExpressions = groupingSetAnalysis.getOriginalExpressions(); + + // Analyze each aggregation slot expression (as written by the user) in the input + // scope. + ImmutableList.Builder slotExpressionsBuilder = ImmutableList.builder(); + for (PivotAggregation aggregation : relation.getAggregations()) { + analysis.recordSubqueries(relation, analyzeExpression(aggregation.getExpression(), inputScope)); + slotExpressionsBuilder.add(aggregation.getExpression()); + } + List slotExpressions = slotExpressionsBuilder.build(); + + // Verify that non-aggregate parts of slot expressions only reference grouping columns. + AggregationAnalyzer.verifySourceAggregations( + groupingExpressions, + inputScope, + slotExpressions, + session, + plannerContext, + accessControl, + analysis); + + // Collect the aggregate function calls inside the slot expressions for the planner. + List aggregates = extractAggregateFunctions(slotExpressions, session, functionResolver, accessControl); + + // Compute output column metadata. The order matches the planner's iteration: + // for each value group, for each aggregation slot. + Set seenNames = new HashSet<>(); + ImmutableList.Builder outputColumns = ImmutableList.builder(); + for (PivotValueGroup valueGroup : relation.getValueGroups()) { + for (PivotAggregation aggregation : relation.getAggregations()) { + String columnName = pivotColumnName(valueGroup, aggregation); + if (!seenNames.add(columnName.toLowerCase(ENGLISH))) { + throw semanticException( + DUPLICATE_COLUMN_NAME, + valueGroup, + "PIVOT produces duplicate output column name: %s", + columnName); + } + outputColumns.add(new PivotOutputColumn(columnName, analysis.getType(aggregation.getExpression()))); + } + } + List pivotOutputColumns = outputColumns.build(); + + // Build output Field list: grouping expressions, then pivot output columns. + ImmutableList.Builder outputFields = ImmutableList.builder(); + for (Expression groupingExpression : groupingExpressions) { + outputFields.add(Field.newUnqualified(deriveColumnName(groupingExpression), analysis.getType(groupingExpression))); + } + for (PivotOutputColumn column : pivotOutputColumns) { + outputFields.add(Field.newUnqualified(column.name(), column.type())); + } + + analysis.registerPivotAnalysis( + relation, + new PivotAnalysis(groupingSetAnalysis, distinctGroupingSets, pivotOutputColumns, aggregates)); + + return createAndAssignScope(relation, scope, outputFields.build()); + } + + private void validatePivotShape(Pivot relation) + { + int pivotColumnArity = relation.getPivotColumns().size(); + for (PivotValueGroup valueGroup : relation.getValueGroups()) { + if (valueGroup.getValues().size() != pivotColumnArity) { + throw semanticException( + INVALID_ARGUMENTS, + valueGroup, + "Number of pivot values (%s) does not match number of pivot columns (%s)", + valueGroup.getValues().size(), + pivotColumnArity); + } + } + + if (relation.getAggregations().size() > 1) { + for (PivotAggregation aggregation : relation.getAggregations()) { + if (aggregation.getAlias().isEmpty()) { + throw semanticException( + MISSING_COLUMN_ALIASES, + aggregation, + "PIVOT with multiple aggregations requires an alias on each aggregation"); + } + } + } + } + + private static String pivotColumnName(PivotValueGroup valueGroup, PivotAggregation aggregation) + { + String valueName = valueGroup.getAlias() + .map(Identifier::getValue) + .orElseGet(() -> valueGroup.getValues().stream() + .map(ExpressionFormatter::formatExpression) + .collect(Collectors.joining("_"))); + String aggregationName = aggregation.getAlias().map(Identifier::getValue).orElse(""); + return aggregationName.isEmpty() ? valueName : valueName + "_" + aggregationName; + } + + private static Optional deriveColumnName(Expression expression) + { + if (expression instanceof Identifier identifier) { + return Optional.of(identifier.getValue()); + } + if (expression instanceof DereferenceExpression dereference) { + return Optional.of(dereference.getField().orElseThrow().getValue()); + } + return Optional.empty(); + } + // this method should run after the `base` relation is processed, so that it is // determined whether the table function is polymorphic private void validateNoNestedTableFunction(Relation base, String context) @@ -4813,111 +4971,120 @@ else if (element instanceof GroupingSets groupingSets) { private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, List outputExpressions) { if (node.getGroupBy().isPresent()) { - ImmutableList.Builder>> cubes = ImmutableList.builder(); - ImmutableList.Builder>> rollups = ImmutableList.builder(); - ImmutableList.Builder>> sets = ImmutableList.builder(); - ImmutableList.Builder complexExpressions = ImmutableList.builder(); - ImmutableList.Builder groupingExpressions = ImmutableList.builder(); - - checkGroupingSetsCount(node.getGroupBy().get()); - for (GroupingElement groupingElement : node.getGroupBy().get().getGroupingElements()) { - if (groupingElement instanceof SimpleGroupBy) { - for (Expression column : groupingElement.getExpressions()) { - // simple GROUP BY expressions allow ordinals or arbitrary expressions - if (column instanceof LongLiteral) { - long ordinal = ((LongLiteral) column).getParsedValue(); - if (ordinal < 1 || ordinal > outputExpressions.size()) { - throw semanticException(INVALID_COLUMN_REFERENCE, column, "GROUP BY position %s is not in select list", ordinal); - } + return analyzeGroupingElements(node, node.getGroupBy().get(), scope, outputExpressions); + } - column = outputExpressions.get(toIntExact(ordinal - 1)); - verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause"); - } - else { - verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause"); - analyzeExpression(column, scope); - } + GroupingSetAnalysis result = new GroupingSetAnalysis(ImmutableList.of(), ImmutableList.of(), ImmutableList.of(), ImmutableList.of(), ImmutableList.of()); - ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column)); - if (field != null) { - sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId()))); - } - else { - analysis.recordSubqueries(node, analyzeExpression(column, scope)); - complexExpressions.add(column); - } + if (hasAggregates(node) || node.getHaving().isPresent()) { + analysis.setGroupingSets(node, result); + } - groupingExpressions.add(column); - } - } - else if (groupingElement instanceof AutoGroupBy) { - // Analyze non-aggregation outputs - for (Expression column : outputExpressions) { - if (containsAggregation(column, this::getResolvedFunction)) { - continue; - } - verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause"); - analyzeExpression(column, scope); + return result; + } - ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column)); - if (field != null) { - sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId()))); - } - else { - analysis.recordSubqueries(node, analyzeExpression(column, scope)); - complexExpressions.add(column); + // Analyzes the elements of a GROUP BY clause and registers a GroupingSetAnalysis + // keyed by `node`. The `outputExpressions` list is used only for resolving ordinal + // references and for the AUTO grouping element; callers without a SELECT list pass + // an empty list, in which case those forms are not allowed. + private GroupingSetAnalysis analyzeGroupingElements(Node node, GroupBy groupBy, Scope scope, List outputExpressions) + { + ImmutableList.Builder>> cubes = ImmutableList.builder(); + ImmutableList.Builder>> rollups = ImmutableList.builder(); + ImmutableList.Builder>> sets = ImmutableList.builder(); + ImmutableList.Builder complexExpressions = ImmutableList.builder(); + ImmutableList.Builder groupingExpressions = ImmutableList.builder(); + + checkGroupingSetsCount(groupBy); + for (GroupingElement groupingElement : groupBy.getGroupingElements()) { + if (groupingElement instanceof SimpleGroupBy) { + for (Expression column : groupingElement.getExpressions()) { + // simple GROUP BY expressions allow ordinals or arbitrary expressions + if (column instanceof LongLiteral) { + long ordinal = ((LongLiteral) column).getParsedValue(); + if (ordinal < 1 || ordinal > outputExpressions.size()) { + throw semanticException(INVALID_COLUMN_REFERENCE, column, "GROUP BY position %s is not in select list", ordinal); } - groupingExpressions.add(column); + column = outputExpressions.get(toIntExact(ordinal - 1)); + verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause"); } - } - else if (groupingElement instanceof GroupingSets element) { - for (Expression column : groupingElement.getExpressions()) { + else { + verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause"); analyzeExpression(column, scope); - if (!analysis.getColumnReferences().contains(NodeRef.of(column))) { - throw semanticException(INVALID_COLUMN_REFERENCE, column, "GROUP BY expression must be a column reference: %s", column); - } - - groupingExpressions.add(column); } - List> groupingSets = element.getSets().stream() - .map(set -> set.stream() - .map(NodeRef::of) - .map(analysis.getColumnReferenceFields()::get) - .map(ResolvedField::getFieldId) - .collect(toImmutableSet())) - .collect(toImmutableList()); - - switch (element.getType()) { - case CUBE -> cubes.add(groupingSets); - case ROLLUP -> rollups.add(groupingSets); - case EXPLICIT -> sets.add(groupingSets); + ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column)); + if (field != null) { + sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId()))); + } + else { + analysis.recordSubqueries(node, analyzeExpression(column, scope)); + complexExpressions.add(column); } + + groupingExpressions.add(column); } } + else if (groupingElement instanceof AutoGroupBy) { + // Analyze non-aggregation outputs + for (Expression column : outputExpressions) { + if (containsAggregation(column, this::getResolvedFunction)) { + continue; + } + verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause"); + analyzeExpression(column, scope); - List expressions = groupingExpressions.build(); - for (Expression expression : expressions) { - Type type = analysis.getType(expression); - if (!type.isComparable()) { - throw semanticException(TYPE_MISMATCH, node, "%s is not comparable, and therefore cannot be used in GROUP BY", type); + ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column)); + if (field != null) { + sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId()))); + } + else { + analysis.recordSubqueries(node, analyzeExpression(column, scope)); + complexExpressions.add(column); + } + + groupingExpressions.add(column); } } + else if (groupingElement instanceof GroupingSets element) { + for (Expression column : groupingElement.getExpressions()) { + analyzeExpression(column, scope); + if (!analysis.getColumnReferences().contains(NodeRef.of(column))) { + throw semanticException(INVALID_COLUMN_REFERENCE, column, "GROUP BY expression must be a column reference: %s", column); + } - GroupingSetAnalysis groupingSets = new GroupingSetAnalysis(expressions, cubes.build(), rollups.build(), sets.build(), complexExpressions.build()); - analysis.setGroupingSets(node, groupingSets); + groupingExpressions.add(column); + } - return groupingSets; + List> groupingSets = element.getSets().stream() + .map(set -> set.stream() + .map(NodeRef::of) + .map(analysis.getColumnReferenceFields()::get) + .map(ResolvedField::getFieldId) + .collect(toImmutableSet())) + .collect(toImmutableList()); + + switch (element.getType()) { + case CUBE -> cubes.add(groupingSets); + case ROLLUP -> rollups.add(groupingSets); + case EXPLICIT -> sets.add(groupingSets); + } + } } - GroupingSetAnalysis result = new GroupingSetAnalysis(ImmutableList.of(), ImmutableList.of(), ImmutableList.of(), ImmutableList.of(), ImmutableList.of()); - - if (hasAggregates(node) || node.getHaving().isPresent()) { - analysis.setGroupingSets(node, result); + List expressions = groupingExpressions.build(); + for (Expression expression : expressions) { + Type type = analysis.getType(expression); + if (!type.isComparable()) { + throw semanticException(TYPE_MISMATCH, node, "%s is not comparable, and therefore cannot be used in GROUP BY", type); + } } - return result; + GroupingSetAnalysis groupingSets = new GroupingSetAnalysis(expressions, cubes.build(), rollups.build(), sets.build(), complexExpressions.build()); + analysis.setGroupingSets(node, groupingSets); + + return groupingSets; } private boolean hasAggregates(QuerySpecification node) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index a608a3a694a7..0ad63130fe65 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -38,6 +38,8 @@ import io.trino.sql.analyzer.Analysis; import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; import io.trino.sql.analyzer.Analysis.MergeAnalysis; +import io.trino.sql.analyzer.Analysis.PivotAnalysis; +import io.trino.sql.analyzer.Analysis.PivotOutputColumn; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.FieldId; @@ -51,6 +53,7 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; import io.trino.sql.ir.Row; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.RelationPlanner.PatternRecognitionComponents; @@ -98,6 +101,9 @@ import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; +import io.trino.sql.tree.Pivot; +import io.trino.sql.tree.PivotAggregation; +import io.trino.sql.tree.PivotValueGroup; import io.trino.sql.tree.Query; import io.trino.sql.tree.QuerySpecification; import io.trino.sql.tree.Relation; @@ -483,6 +489,157 @@ public RelationPlan plan(QuerySpecification node) outerContext); } + public RelationPlan planPivot(Pivot node, RelationPlan inputPlan) + { + PivotAnalysis pivotAnalysis = analysis.getPivotAnalysis(node); + GroupingSetAnalysis groupingSetAnalysis = pivotAnalysis.groupingSetAnalysis(); + List aggregateCalls = pivotAnalysis.aggregates(); + + PlanBuilder subPlan = newPlanBuilder(inputPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); + + // Project everything the aggregation predicate and arguments need: aggregate + // arguments, complex grouping expressions, pivot column references, and pivot + // value expressions. + ImmutableList.Builder inputBuilder = ImmutableList.builder(); + for (FunctionCall aggregate : aggregateCalls) { + aggregate.getArguments().stream() + .filter(argument -> !(argument instanceof LambdaExpression)) + .forEach(inputBuilder::add); + } + inputBuilder.addAll(groupingSetAnalysis.getComplexExpressions()); + inputBuilder.addAll(node.getPivotColumns()); + for (PivotValueGroup valueGroup : node.getValueGroups()) { + inputBuilder.addAll(valueGroup.getValues()); + } + List inputs = inputBuilder.build(); + + subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, analysis.getSubqueries(node)); + subPlan = subPlan.appendProjections(inputs, symbolAllocator, idAllocator); + + PlanAndMappings coercions = coerce(subPlan, inputs, analysis, idAllocator, symbolAllocator); + subPlan = coercions.getSubPlan(); + + // Build a boolean predicate symbol per value group: pivot_col_1 = value_1 AND ... + // Cast each value to the corresponding pivot column's type so comparison types align, + // including the unknown-typed NULL literal. + Map, Symbol> predicateSymbols = new LinkedHashMap<>(); + Assignments.Builder predicateAssignments = Assignments.builder(); + predicateAssignments.putIdentities(subPlan.getRoot().getOutputSymbols()); + for (PivotValueGroup valueGroup : node.getValueGroups()) { + Expression predicate = null; + for (int i = 0; i < node.getPivotColumns().size(); i++) { + Symbol columnSymbol = coercions.get(node.getPivotColumns().get(i)); + Symbol valueSymbol = coercions.get(valueGroup.getValues().get(i)); + Expression valueExpression = valueSymbol.toSymbolReference(); + if (!valueSymbol.type().equals(columnSymbol.type())) { + valueExpression = new Cast(valueExpression, columnSymbol.type()); + } + Expression equality = new Comparison(Comparison.Operator.EQUAL, columnSymbol.toSymbolReference(), valueExpression); + predicate = (predicate == null) ? equality : Logical.and(predicate, equality); + } + Symbol predicateSymbol = symbolAllocator.newSymbol("pivot_match", BOOLEAN); + predicateAssignments.put(predicateSymbol, predicate); + predicateSymbols.put(NodeRef.of(valueGroup), predicateSymbol); + } + subPlan = subPlan.withNewRoot(new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), predicateAssignments.build())); + + GroupingSetsPlan groupingSets = planGroupingSets(subPlan, pivotAnalysis.distinctGroupingSets(), groupingSetAnalysis); + subPlan = groupingSets.getSubPlan(); + + // Build one Aggregation per (value group, aggregate call) with FILTER set to the + // matching value group's predicate symbol. Each (value group, call) pair gets its + // own output Symbol so the slot expressions can be projected per group. + Map, Symbol>> aggregateSymbolsByGroup = new LinkedHashMap<>(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + PlanBuilder finalSubPlan = subPlan; + for (PivotValueGroup valueGroup : node.getValueGroups()) { + Map, Symbol> perGroup = new LinkedHashMap<>(); + Symbol filterSymbol = predicateSymbols.get(NodeRef.of(valueGroup)); + for (FunctionCall function : aggregateCalls) { + Symbol output = symbolAllocator.newSymbol(function.getName().toString(), analysis.getType(function)); + Aggregation aggregation = new Aggregation( + analysis.getResolvedFunction(function).orElseThrow(), + function.getArguments().stream() + .map(argument -> { + if (argument instanceof LambdaExpression) { + return finalSubPlan.rewrite(argument); + } + return coercions.get(argument).toSymbolReference(); + }) + .collect(toImmutableList()), + function.isDistinct(), + Optional.of(filterSymbol), + function.getOrderBy().map(orderBy -> translateOrderingScheme(orderBy.getSortItems(), coercions::get)), + Optional.empty()); + aggregations.put(output, aggregation); + perGroup.put(NodeRef.of(function), output); + } + aggregateSymbolsByGroup.put(valueGroup, perGroup); + } + + ImmutableSet.Builder globalGroupingSets = ImmutableSet.builder(); + for (int i = 0; i < groupingSets.getGroupingSets().size(); i++) { + if (groupingSets.getGroupingSets().get(i).isEmpty()) { + globalGroupingSets.add(i); + } + } + ImmutableSet.Builder groupingKeys = ImmutableSet.builder(); + groupingSets.getGroupingSets().stream() + .flatMap(List::stream) + .distinct() + .forEach(groupingKeys::add); + groupingSets.getGroupIdSymbol().ifPresent(groupingKeys::add); + + AggregationNode aggregationNode = new AggregationNode( + idAllocator.getNextId(), + subPlan.getRoot(), + aggregations.buildKeepingLast(), + groupingSets( + groupingKeys.build(), + groupingSets.getGroupingSets().size(), + globalGroupingSets.build()), + ImmutableList.of(), + AggregationNode.Step.SINGLE, + groupingSets.getGroupIdSymbol()); + subPlan = new PlanBuilder(subPlan.getTranslations(), aggregationNode); + + // Final projection: grouping expressions identity-projected, then for each + // (value group, slot) translate the user's slot expression with a per-group + // aggregate-call -> aggregation-symbol mapping, producing the synthesized output + // column. + Assignments.Builder outputAssignments = Assignments.builder(); + ImmutableList.Builder outputSymbols = ImmutableList.builder(); + + for (io.trino.sql.tree.Expression groupingExpression : groupingSetAnalysis.getOriginalExpressions()) { + Symbol symbol = subPlan.translate(groupingExpression); + outputAssignments.putIdentity(symbol); + outputSymbols.add(symbol); + } + + int outputColumnIndex = 0; + List outputColumnMetadata = pivotAnalysis.outputColumns(); + for (PivotValueGroup valueGroup : node.getValueGroups()) { + Map, Symbol> perGroupSymbols = aggregateSymbolsByGroup.get(valueGroup); + ImmutableMap.Builder, Symbol> mappings = ImmutableMap.builder(); + for (FunctionCall function : aggregateCalls) { + mappings.put(scopeAwareKey(function, analysis, subPlan.getScope()), perGroupSymbols.get(NodeRef.of(function))); + } + TranslationMap groupTranslations = subPlan.getTranslations().withAdditionalMappings(mappings.buildKeepingLast()); + + for (PivotAggregation aggregation : node.getAggregations()) { + Expression slotIr = groupTranslations.rewrite(aggregation.getExpression()); + PivotOutputColumn column = outputColumnMetadata.get(outputColumnIndex++); + Symbol output = symbolAllocator.newSymbol(column.name(), column.type()); + outputAssignments.put(output, slotIr); + outputSymbols.add(output); + } + } + + ProjectNode outputProject = new ProjectNode(idAllocator.getNextId(), aggregationNode, outputAssignments.build()); + + return new RelationPlan(outputProject, analysis.getScope(node), outputSymbols.build(), outerContext); + } + private static boolean hasExpressionsToUnfold(List selectExpressions) { return selectExpressions.stream() @@ -1186,14 +1343,15 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) PlanAndMappings coercions = coerce(subPlan, inputs, analysis, idAllocator, symbolAllocator); subPlan = coercions.getSubPlan(); - GroupingSetsPlan groupingSets = planGroupingSets(subPlan, node, groupingSetAnalysis); + boolean distinctGroupingSets = node.getGroupBy().isPresent() && node.getGroupBy().get().isDistinct(); + GroupingSetsPlan groupingSets = planGroupingSets(subPlan, distinctGroupingSets, groupingSetAnalysis); subPlan = planAggregation(groupingSets.getSubPlan(), groupingSets.getGroupingSets(), groupingSets.getGroupIdSymbol(), analysis.getAggregates(node), coercions::get); return planGroupingOperations(subPlan, node, groupingSets.getGroupIdSymbol(), groupingSets.getColumnOnlyGroupingSets()); } - private GroupingSetsPlan planGroupingSets(PlanBuilder subPlan, QuerySpecification node, GroupingSetAnalysis groupingSetAnalysis) + private GroupingSetsPlan planGroupingSets(PlanBuilder subPlan, boolean distinctGroupingSets, GroupingSetAnalysis groupingSetAnalysis) { Map groupingSetMappings = new LinkedHashMap<>(); @@ -1229,7 +1387,7 @@ private GroupingSetsPlan planGroupingSets(PlanBuilder subPlan, QuerySpecificatio // This tracks the grouping sets before complex expressions are considered. // It's also used to compute the descriptors needed to implement grouping() List> columnOnlyGroupingSets = enumerateGroupingSets(groupingSetAnalysis); - if (node.getGroupBy().isPresent() && node.getGroupBy().get().isDistinct()) { + if (distinctGroupingSets) { columnOnlyGroupingSets = columnOnlyGroupingSets.stream() .distinct() .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index eeede96270dc..ea0edd7c649e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -135,6 +135,7 @@ import io.trino.sql.tree.OrdinalityColumn; import io.trino.sql.tree.PatternRecognitionRelation; import io.trino.sql.tree.PatternSearchMode; +import io.trino.sql.tree.Pivot; import io.trino.sql.tree.PlanLeaf; import io.trino.sql.tree.PlanParentChild; import io.trino.sql.tree.PlanSiblings; @@ -1867,6 +1868,14 @@ protected RelationPlan visitTableSubquery(TableSubquery node, Void context) return new RelationPlan(plan.getRoot(), analysis.getScope(node), plan.getFieldMappings(), outerContext); } + @Override + protected RelationPlan visitPivot(Pivot node, Void context) + { + RelationPlan inputPlan = process(node.getInput(), context); + return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, plannerContext, outerContext, session, recursiveSubqueries) + .planPivot(node, inputPlan); + } + @Override protected RelationPlan visitQuery(Query node, Void context) { diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestPivot.java b/core/trino-main/src/test/java/io/trino/sql/query/TestPivot.java new file mode 100644 index 000000000000..08970dfc5ab7 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestPivot.java @@ -0,0 +1,388 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.query; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import static io.trino.spi.StandardErrorCode.DUPLICATE_COLUMN_NAME; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static io.trino.spi.StandardErrorCode.MISSING_COLUMN_ALIASES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) +public class TestPivot +{ + private static final String SALES = + """ + (VALUES + ('NA', 1, BIGINT '100'), + ('NA', 1, BIGINT '50'), + ('NA', 2, BIGINT '70'), + ('EU', 1, BIGINT '40'), + ('EU', 3, BIGINT '20') + ) AS sales(region, month, amount) + """; + + private final QueryAssertions assertions = new QueryAssertions(); + + @AfterAll + public void teardown() + { + assertions.close(); + } + + @Test + public void testSingleAggregationNoGroupBy() + { + // No implicit grouping: collapses to a single row. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount) FOR month IN (1 AS jan, 2 AS feb, 3 AS mar)) + """.formatted(SALES))) + .matches("VALUES (BIGINT '190', BIGINT '70', BIGINT '20')"); + } + + @Test + public void testSingleAggregationWithGroupBy() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount) FOR month IN (1 AS jan, 2 AS feb) GROUP BY region) + ORDER BY region + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + ('EU', BIGINT '40', CAST(NULL AS BIGINT)), + ('NA', BIGINT '150', BIGINT '70') + """); + } + + @Test + public void testMultipleAggregations() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) AS total, + count(amount) AS cnt + FOR month IN (1 AS jan, 2 AS feb) + GROUP BY region + ) + ORDER BY region + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + ('EU', BIGINT '40', BIGINT '1', CAST(NULL AS BIGINT), BIGINT '0'), + ('NA', BIGINT '150', BIGINT '2', BIGINT '70', BIGINT '1') + """); + } + + @Test + public void testMultiplePivotColumns() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) + FOR (region, month) IN ( + ('NA', 1) AS na_jan, + ('NA', 2) AS na_feb, + ('EU', 1) AS eu_jan + ) + ) + """.formatted(SALES))) + .matches("VALUES (BIGINT '150', BIGINT '70', BIGINT '40')"); + } + + @Test + public void testAggregationExpression() + { + // Aggregation slot is a general expression containing aggregates; + // FILTER is attached to each aggregate inside. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) - max(amount) AS minor + FOR month IN (1 AS jan) + GROUP BY region + ) + ORDER BY region + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + ('EU', BIGINT '0'), + ('NA', BIGINT '50') + """); + } + + @Test + public void testValueWithoutAlias() + { + // Unaliased values use literal SQL text as the column name, so 1 and '1' + // produce distinct columns (both nominally compare to month though only one + // matches month's bigint type). + assertThat(assertions.query( + """ + SELECT "1" + FROM %s + PIVOT (sum(amount) FOR month IN (1)) + """.formatted(SALES))) + .matches("VALUES BIGINT '190'"); + } + + @Test + public void testNullValue() + { + // NULL value uses Trino's standard '=' semantics: never matches, so the + // synthesized column is the empty-input aggregation result (0 for count). + // Single-agg with both value-alias and agg-alias produces "valueAlias_aggAlias". + assertThat(assertions.query( + """ + SELECT cnt_null_cnt + FROM %s + PIVOT (count(amount) AS cnt FOR month IN (NULL AS cnt_null)) + """.formatted(SALES))) + .matches("VALUES BIGINT '0'"); + } + + @Test + public void testRelationAlias() + { + assertThat(assertions.query( + """ + SELECT p.r, p.jan + FROM %s + PIVOT (sum(amount) FOR month IN (1 AS jan) GROUP BY region) AS p (r, jan) + ORDER BY p.r + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + ('EU', BIGINT '40'), + ('NA', BIGINT '150') + """); + } + + @Test + public void testPivotOfPivot() + { + // PIVOT output is a relation; another PIVOT can apply to it. The first PIVOT + // is wrapped in a subquery so the second one can attach. + assertThat(assertions.query( + """ + SELECT * + FROM ( + SELECT * + FROM %s + PIVOT (sum(amount) FOR month IN (1 AS jan, 2 AS feb) GROUP BY region) + ) PIVOT (sum(jan) FOR region IN ('NA' AS na_total)) + """.formatted(SALES))) + .matches("VALUES (BIGINT '150')"); + } + + @Test + public void testCoercion() + { + // Pivot column is BIGINT; integer literal in IN coerces. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount) FOR month IN (1 AS jan)) + """.formatted(SALES))) + .matches("VALUES BIGINT '190'"); + } + + @Test + public void testMultiplePivotColumnsTupleArityMismatch() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount) FOR (region, month) IN (('NA', 1), ('EU'))) + """.formatted(SALES))) + .failure() + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessageContaining("Number of pivot values"); + } + + @Test + public void testMultipleAggregationsRequireAlias() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount), count(amount) FOR month IN (1)) + """.formatted(SALES))) + .failure() + .hasErrorCode(MISSING_COLUMN_ALIASES) + .hasMessageContaining("PIVOT with multiple aggregations requires an alias"); + } + + @Test + public void testDuplicateOutputColumnName() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount) FOR month IN (1 AS jan, 2 AS jan)) + """.formatted(SALES))) + .failure() + .hasErrorCode(DUPLICATE_COLUMN_NAME) + .hasMessageContaining("PIVOT produces duplicate output column name"); + } + + @Test + public void testGroupingSetsInsidePivot() + { + // GROUP BY GROUPING SETS within PIVOT — emits a row per grouping set with NULL for + // the missing dimensions. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) AS total + FOR month IN (1 AS jan) + GROUP BY GROUPING SETS ((region), ()) + ) + ORDER BY region NULLS FIRST + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + (CAST(NULL AS varchar(2)), BIGINT '190'), + ('EU', BIGINT '40'), + ('NA', BIGINT '150') + """); + } + + @Test + public void testCubeInsidePivot() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) AS total + FOR month IN (1 AS jan) + GROUP BY CUBE (region) + ) + ORDER BY region NULLS FIRST + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + (CAST(NULL AS varchar(2)), BIGINT '190'), + ('EU', BIGINT '40'), + ('NA', BIGINT '150') + """); + } + + @Test + public void testRollupInsidePivot() + { + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) AS total + FOR month IN (1 AS jan) + GROUP BY ROLLUP (region) + ) + ORDER BY region NULLS FIRST + """.formatted(SALES))) + .ordered() + .matches( + """ + VALUES + (CAST(NULL AS varchar(2)), BIGINT '190'), + ('EU', BIGINT '40'), + ('NA', BIGINT '150') + """); + } + + @Test + public void testEmptyGroupBy() + { + // GROUP BY () inside PIVOT — equivalent to no GROUP BY: collapses to a single row. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT ( + sum(amount) AS total + FOR month IN (1 AS jan) + GROUP BY () + ) + """.formatted(SALES))) + .matches("VALUES (BIGINT '190')"); + } + + @Test + public void testValueIsExpression() + { + // IN values can be arbitrary constant expressions; coercion follows the + // pivot column's type. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (sum(amount) FOR month IN (1 + 0 AS jan, 2 * 1 AS feb)) + """.formatted(SALES))) + .matches("VALUES (BIGINT '190', BIGINT '70')"); + } + + @Test + public void testAggregationSlotMustBeAggregating() + { + // No aggregate function in the slot — caught by AggregationAnalyzer in the + // rewritten query. + assertThat(assertions.query( + """ + SELECT * + FROM %s + PIVOT (amount FOR month IN (1 AS jan) GROUP BY region) + """.formatted(SALES))) + .failure() + .hasMessageContaining("must be an aggregate expression or appear in GROUP BY clause"); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 0796986becaa..23a39509d55b 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -111,6 +111,9 @@ import io.trino.sql.tree.OrdinalityColumn; import io.trino.sql.tree.ParameterDeclaration; import io.trino.sql.tree.PatternRecognitionRelation; +import io.trino.sql.tree.Pivot; +import io.trino.sql.tree.PivotAggregation; +import io.trino.sql.tree.PivotValueGroup; import io.trino.sql.tree.PlanLeaf; import io.trino.sql.tree.PlanParentChild; import io.trino.sql.tree.PlanSiblings; @@ -1007,7 +1010,7 @@ protected Void visitSampledRelation(SampledRelation node, Integer indent) private void processRelationSuffix(Relation relation, Integer indent) { - if ((relation instanceof AliasedRelation) || (relation instanceof SampledRelation) || (relation instanceof PatternRecognitionRelation)) { + if ((relation instanceof AliasedRelation) || (relation instanceof SampledRelation) || (relation instanceof PatternRecognitionRelation) || (relation instanceof Pivot)) { builder.append("( "); process(relation, indent + 1); append(indent, ")"); @@ -1017,6 +1020,64 @@ private void processRelationSuffix(Relation relation, Integer indent) } } + @Override + protected Void visitPivot(Pivot node, Integer indent) + { + processRelationSuffix(node.getInput(), indent); + + builder.append(" PIVOT (\n"); + append(indent + 1, node.getAggregations().stream() + .map(this::formatPivotAggregation) + .collect(joining(", "))) + .append("\n"); + append(indent + 1, "FOR ") + .append(formatPivotColumns(node.getPivotColumns())) + .append(" IN (") + .append(node.getValueGroups().stream() + .map(this::formatPivotValueGroup) + .collect(joining(", "))) + .append(")\n"); + node.getGroupBy().ifPresent(groupBy -> + append(indent + 1, "GROUP BY " + (groupBy.isDistinct() ? "DISTINCT " : "") + formatGroupBy(groupBy.getGroupingElements())) + .append("\n")); + append(indent, ")"); + return null; + } + + private String formatPivotAggregation(PivotAggregation aggregation) + { + String result = formatExpression(aggregation.getExpression()); + if (aggregation.getAlias().isPresent()) { + result += " AS " + formatName(aggregation.getAlias().get()); + } + return result; + } + + private String formatPivotColumns(List pivotColumns) + { + String formatted = pivotColumns.stream() + .map(SqlFormatter::formatExpression) + .collect(joining(", ")); + if (pivotColumns.size() == 1) { + return formatted; + } + return "(" + formatted + ")"; + } + + private String formatPivotValueGroup(PivotValueGroup valueGroup) + { + String formatted = valueGroup.getValues().stream() + .map(SqlFormatter::formatExpression) + .collect(joining(", ")); + if (valueGroup.getValues().size() != 1) { + formatted = "(" + formatted + ")"; + } + if (valueGroup.getAlias().isPresent()) { + formatted += " AS " + formatName(valueGroup.getAlias().get()); + } + return formatted; + } + @Override protected Void visitValues(Values node, Integer indent) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index b3155362bc79..18fabfd80af9 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -211,6 +211,9 @@ import io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch; import io.trino.sql.tree.PatternSearchMode; import io.trino.sql.tree.PatternVariable; +import io.trino.sql.tree.Pivot; +import io.trino.sql.tree.PivotAggregation; +import io.trino.sql.tree.PivotValueGroup; import io.trino.sql.tree.PlanLeaf; import io.trino.sql.tree.PlanParentChild; import io.trino.sql.tree.PlanSiblings; @@ -2004,7 +2007,7 @@ else if (context.joinType().FULL() != null) { @Override public Node visitSampledRelation(SqlBaseParser.SampledRelationContext context) { - Relation child = (Relation) visit(context.patternRecognition()); + Relation child = (Relation) visit(context.pivot()); if (context.TABLESAMPLE() == null) { return child; @@ -2067,6 +2070,70 @@ public Node visitMeasureDefinition(SqlBaseParser.MeasureDefinitionContext contex return new MeasureDefinition(getLocation(context), (Expression) visit(context.expression()), (Identifier) visit(context.identifier())); } + @Override + public Node visitPivot(SqlBaseParser.PivotContext context) + { + Relation child = (Relation) visit(context.patternRecognition()); + + if (context.PIVOT() == null) { + return child; + } + + List aggregations = visit(context.pivotAggregation(), PivotAggregation.class); + List pivotColumns = buildPivotColumns(context.pivotColumns()); + List valueGroups = visit(context.pivotValueGroup(), PivotValueGroup.class); + + Optional groupBy = Optional.empty(); + if (context.GROUP() != null) { + groupBy = Optional.of((GroupBy) visit(context.groupBy())); + } + + Pivot pivot = new Pivot(getLocation(context), child, aggregations, pivotColumns, valueGroups, groupBy); + + if (context.identifier() == null) { + return pivot; + } + + List aliases = null; + if (context.columnAliases() != null) { + aliases = visit(context.columnAliases().identifier(), Identifier.class); + } + + return new AliasedRelation(getLocation(context), pivot, (Identifier) visit(context.identifier()), aliases); + } + + @Override + public Node visitPivotAggregation(SqlBaseParser.PivotAggregationContext context) + { + Optional alias = Optional.empty(); + if (context.identifier() != null) { + alias = Optional.of((Identifier) visit(context.identifier())); + } + return new PivotAggregation(getLocation(context), (Expression) visit(context.expression()), alias); + } + + @Override + public Node visitPivotValueGroup(SqlBaseParser.PivotValueGroupContext context) + { + Optional alias = Optional.empty(); + if (context.identifier() != null) { + alias = Optional.of((Identifier) visit(context.identifier())); + } + return new PivotValueGroup( + getLocation(context), + visit(context.expression(), Expression.class), + alias); + } + + private List buildPivotColumns(SqlBaseParser.PivotColumnsContext context) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (SqlBaseParser.QualifiedNameContext qualifiedNameContext : context.qualifiedName()) { + builder.add(DereferenceExpression.from(getQualifiedName(qualifiedNameContext))); + } + return builder.build(); + } + private static Optional getRowsPerMatch(SqlBaseParser.RowsPerMatchContext context) { if (context == null) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index a0d0f4acaa70..bb53c8378587 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -537,6 +537,21 @@ protected R visitSampledRelation(SampledRelation node, C context) return visitRelation(node, context); } + protected R visitPivot(Pivot node, C context) + { + return visitRelation(node, context); + } + + protected R visitPivotAggregation(PivotAggregation node, C context) + { + return visitNode(node, context); + } + + protected R visitPivotValueGroup(PivotValueGroup node, C context) + { + return visitNode(node, context); + } + protected R visitJoin(Join node, C context) { return visitRelation(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Pivot.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Pivot.java new file mode 100644 index 000000000000..96d72fdcfb10 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Pivot.java @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class Pivot + extends Relation +{ + private final Relation input; + private final List aggregations; + private final List pivotColumns; + private final List valueGroups; + private final Optional groupBy; + + public Pivot( + NodeLocation location, + Relation input, + List aggregations, + List pivotColumns, + List valueGroups, + Optional groupBy) + { + super(Optional.of(location)); + this.input = requireNonNull(input, "input is null"); + this.aggregations = ImmutableList.copyOf(aggregations); + checkArgument(!this.aggregations.isEmpty(), "aggregations is empty"); + this.pivotColumns = ImmutableList.copyOf(pivotColumns); + checkArgument(!this.pivotColumns.isEmpty(), "pivotColumns is empty"); + this.valueGroups = ImmutableList.copyOf(valueGroups); + checkArgument(!this.valueGroups.isEmpty(), "valueGroups is empty"); + this.groupBy = requireNonNull(groupBy, "groupBy is null"); + } + + public Relation getInput() + { + return input; + } + + public List getAggregations() + { + return aggregations; + } + + public List getPivotColumns() + { + return pivotColumns; + } + + public List getValueGroups() + { + return valueGroups; + } + + public Optional getGroupBy() + { + return groupBy; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitPivot(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .add(input) + .addAll(aggregations) + .addAll(pivotColumns) + .addAll(valueGroups) + .addAll(groupBy.stream().toList()) + .build(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("input", input) + .add("aggregations", aggregations) + .add("pivotColumns", pivotColumns) + .add("valueGroups", valueGroups) + .add("groupBy", groupBy.orElse(null)) + .omitNullValues() + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Pivot pivot = (Pivot) o; + return Objects.equals(input, pivot.input) && + Objects.equals(aggregations, pivot.aggregations) && + Objects.equals(pivotColumns, pivot.pivotColumns) && + Objects.equals(valueGroups, pivot.valueGroups) && + Objects.equals(groupBy, pivot.groupBy); + } + + @Override + public int hashCode() + { + return Objects.hash(input, aggregations, pivotColumns, valueGroups, groupBy); + } + + @Override + public boolean shallowEquals(Node other) + { + return sameClass(this, other); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/PivotAggregation.java b/core/trino-parser/src/main/java/io/trino/sql/tree/PivotAggregation.java new file mode 100644 index 000000000000..ef2e1a9dfa91 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/PivotAggregation.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class PivotAggregation + extends Node +{ + private final Expression expression; + private final Optional alias; + + public PivotAggregation(NodeLocation location, Expression expression, Optional alias) + { + super(location); + this.expression = requireNonNull(expression, "expression is null"); + this.alias = requireNonNull(alias, "alias is null"); + } + + public Expression getExpression() + { + return expression; + } + + public Optional getAlias() + { + return alias; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitPivotAggregation(this, context); + } + + @Override + public List getChildren() + { + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add(expression); + alias.ifPresent(builder::add); + return builder.build(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("expression", expression) + .add("alias", alias.orElse(null)) + .omitNullValues() + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PivotAggregation that = (PivotAggregation) o; + return Objects.equals(expression, that.expression) && + Objects.equals(alias, that.alias); + } + + @Override + public int hashCode() + { + return Objects.hash(expression, alias); + } + + @Override + public boolean shallowEquals(Node other) + { + if (!sameClass(this, other)) { + return false; + } + return Objects.equals(alias, ((PivotAggregation) other).alias); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/PivotValueGroup.java b/core/trino-parser/src/main/java/io/trino/sql/tree/PivotValueGroup.java new file mode 100644 index 000000000000..888f23739c4a --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/PivotValueGroup.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class PivotValueGroup + extends Node +{ + private final List values; + private final Optional alias; + + public PivotValueGroup(NodeLocation location, List values, Optional alias) + { + super(location); + requireNonNull(values, "values is null"); + checkArgument(!values.isEmpty(), "values is empty"); + this.values = ImmutableList.copyOf(values); + this.alias = requireNonNull(alias, "alias is null"); + } + + public List getValues() + { + return values; + } + + public Optional getAlias() + { + return alias; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitPivotValueGroup(this, context); + } + + @Override + public List getChildren() + { + ImmutableList.Builder builder = ImmutableList.builder(); + builder.addAll(values); + alias.ifPresent(builder::add); + return builder.build(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("values", values) + .add("alias", alias.orElse(null)) + .omitNullValues() + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PivotValueGroup that = (PivotValueGroup) o; + return Objects.equals(values, that.values) && + Objects.equals(alias, that.alias); + } + + @Override + public int hashCode() + { + return Objects.hash(values, alias); + } + + @Override + public boolean shallowEquals(Node other) + { + if (!sameClass(this, other)) { + return false; + } + return Objects.equals(alias, ((PivotValueGroup) other).alias); + } +} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 2394c6a35961..23ff6b540c18 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -7927,6 +7927,24 @@ public void testSetSessionAuthorization() .isEqualTo(new SetSessionAuthorization(location(1, 1), new StringLiteral(location(1, 27), "null"))); } + @Test + public void testPivot() + { + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) FOR month IN (1, 2))"); + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) AS total FOR month IN (1 AS jan, 2 AS feb))"); + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) AS total, avg(amount) AS mean FOR month IN (1 AS jan, 2 AS feb))"); + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) FOR (region, month) IN (('NA', 1), ('EU', 1), ('NA', 2) AS na_feb))"); + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) FOR month IN (1) GROUP BY region)"); + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) FOR month IN (1)) AS p"); + assertRoundtrip("SELECT * FROM t PIVOT (sum(amount) FOR month IN (1)) AS p (r, jan)"); + assertRoundtrip("SELECT pivot AS x FROM (SELECT 1 AS pivot)"); + } + + private static void assertRoundtrip(@Language("SQL") String sql) + { + assertFormattedSql(SQL_PARSER, SQL_PARSER.createStatement(sql)); + } + @Test public void testResetSessionAuthorization() { diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index f2a29b17b208..8952647b6610 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -53,7 +53,7 @@ private static Stream statements() "line 1:15: mismatched input '''. Expecting: '(', 'JSON_TABLE', 'LATERAL', 'NEAREST', 'TABLE', 'UNNEST', "), Arguments.of("select *\nfrom x\nfrom", "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FETCH', 'FOR', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', " + - "'LIMIT', 'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', , "), + "'LIMIT', 'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'PIVOT', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', , "), Arguments.of( "select *\nfrom x\nwhere from", "line 3:7: mismatched input 'from'. Expecting: "), @@ -158,7 +158,7 @@ private static Stream statements() "line 1:23: mismatched input ''. Expecting: 'WHERE'"), Arguments.of("SELECT * FROM t t x", "line 1:19: mismatched input 'x'. Expecting: '(', ',', 'CROSS', 'EXCEPT', 'FETCH', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', " + - "'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', "), + "'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'PIVOT', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', "), Arguments.of( "SELECT * FROM t WHERE EXISTS (", "line 1:31: mismatched input ''. Expecting: "), diff --git a/docs/src/main/sphinx/language/sql-support.md b/docs/src/main/sphinx/language/sql-support.md index 1e24e7ded92a..3fe091524f75 100644 --- a/docs/src/main/sphinx/language/sql-support.md +++ b/docs/src/main/sphinx/language/sql-support.md @@ -61,7 +61,7 @@ catalogs](/admin/properties-catalog): The following statements provide read access to data and metadata exposed by a connector accessing a data source. They are supported by all connectors: -- {doc}`/sql/select` including {doc}`/sql/match-recognize` +- {doc}`/sql/select` including {doc}`/sql/match-recognize` and {doc}`/sql/pivot` - {doc}`/sql/describe` - {doc}`/sql/show-catalogs` - {doc}`/sql/show-columns` diff --git a/docs/src/main/sphinx/sql.md b/docs/src/main/sphinx/sql.md index 03c20b99d213..c04e2688ec41 100644 --- a/docs/src/main/sphinx/sql.md +++ b/docs/src/main/sphinx/sql.md @@ -52,6 +52,7 @@ sql/grant-roles sql/insert sql/match-recognize sql/merge +sql/pivot sql/prepare sql/refresh-materialized-view sql/reset-session diff --git a/docs/src/main/sphinx/sql/pivot.md b/docs/src/main/sphinx/sql/pivot.md new file mode 100644 index 000000000000..d1d235dfedb0 --- /dev/null +++ b/docs/src/main/sphinx/sql/pivot.md @@ -0,0 +1,175 @@ +# PIVOT + +## Synopsis + +```text +PIVOT ( + aggregation [ [ AS ] aggregation_alias ] [, ...] + FOR pivot_column [, (pivot_column [, ...]) ] IN ( pivot_value_group [, ...] ) + [ GROUP BY grouping_element [, ...] ] + ) +``` + +where `pivot_value_group` is one of + +```text +expression [ [ AS ] value_alias ] +( expression, expression [, ...] ) [ [ AS ] value_alias ] +``` + +## Description + +The `PIVOT` clause is an optional subclause of the `FROM` clause. It rotates +rows of an input relation into output columns by partitioning the rows on +one or more *pivot columns* and computing one or more *aggregations* for +each *pivot value*. The input to a pivot is a table, a view, or a subquery. +The output of a pivot is a relation, so it can itself appear in a `FROM` +clause, be aliased, or be the input to another `PIVOT`. + +`PIVOT` is useful when each row in the input represents one observation +along a categorical dimension (such as a month, region, or status), and the +report should display one column per category. Common use cases include: + +- summarizing measurements by time bucket, +- producing a column per status or category, +- comparing aggregated metrics side by side without writing repetitive + `CASE` or `FILTER` expressions. + +## Example + +In the following example, `sales` records one row per region/month, and +`PIVOT` produces one column per month within each region: + +```sql +SELECT * +FROM sales PIVOT ( + sum(amount) AS total + FOR month IN (1 AS jan, 2 AS feb, 3 AS mar) + GROUP BY region + ) +``` + +The output has columns `region`, `jan_total`, `feb_total`, `mar_total`. + +In the following sections, all subclauses of the `PIVOT` clause are +explained. + +## Aggregations + +```sql +sum(amount) AS total +``` + +Each aggregation is an expression that contains one or more aggregate +function calls. The expression is evaluated once per pivot value, with each +aggregate scoped to the rows that match that pivot value. + +The aggregation alias becomes part of the output column names (see +[](pivot-output)). Aliases are optional in the single-aggregation case but +required when a `PIVOT` declares more than one aggregation. Without +aliases, every output column for the multi-aggregation case would collide. + +`PIVOT` accepts any expression Trino accepts as an aggregating select +item. For example, the following all work: + +```sql +sum(amount) -- single aggregate +avg(amount) * 100 AS pct -- expression over an aggregate +sum(amount) - sum(refund) AS net -- multiple aggregates in one slot +``` + +When the slot expression contains multiple aggregate calls, the pivot +filter is applied to each aggregate individually, so `sum(amount) - +sum(refund)` filters both `sum`s to the rows for the current pivot value. + +## Pivot column and IN list + +```sql +FOR month IN (1 AS jan, 2 AS feb, 3 AS mar) +``` + +The `FOR` clause names the pivot column (or, for compound keys, a +parenthesised list of pivot columns) and the `IN` clause supplies the +values that become output columns. Each value is a constant expression and +is coerced to the pivot column's type using the standard implicit +coercion rules. + +For multiple pivot columns, supply tuple values in matching order: + +```sql +FOR (region, month) IN (('NA', 1) AS na_jan, ('EU', 1) AS eu_jan) +``` + +Each tuple must have the same arity as the pivot column list. + +The value alias controls the output column name for that value. It is +strongly recommended in practice — without it, the column name is +derived from the SQL text of the value expression (so `1` and `'1'` +become distinct columns named `1` and `'1'`). + +`NULL` is a permitted value, but it is treated using Trino's standard +`=` semantics: the predicate `pivot_column = NULL` is `UNKNOWN`, so the +corresponding output column always carries the empty-input aggregation +result (`NULL` for `sum`, `0` for `count`, and so on). To produce a +column for rows where the pivot column is `NULL`, supply that bucket +explicitly in the source relation rather than relying on a `NULL` IN +value. + +## GROUP BY + +```sql +GROUP BY region +``` + +The optional `GROUP BY` clause inside `PIVOT` controls which dimensions +are preserved as additional output columns. It accepts the same forms as +a top-level {doc}`GROUP BY