From 4ab52e4a5e2abd31d7b28aa506dd788e11f62091 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 12 Apr 2023 21:47:16 -0700 Subject: [PATCH] Remove unnecessary Cube and Rollup classes GroupingElement, Cube and Rollup contain the same information and look very similar. Replace them with a single class and and enum to indicate the type of grouping sets clause they represent. --- .../trino/sql/analyzer/StatementAnalyzer.java | 72 +++++--------- .../io/trino/sql/ExpressionFormatter.java | 29 +++--- .../java/io/trino/sql/parser/AstBuilder.java | 11 ++- .../java/io/trino/sql/tree/AstVisitor.java | 10 -- .../src/main/java/io/trino/sql/tree/Cube.java | 98 ------------------- .../sql/tree/DefaultTraversalVisitor.java | 12 --- .../java/io/trino/sql/tree/GroupingSets.java | 33 +++++-- .../main/java/io/trino/sql/tree/Rollup.java | 98 ------------------- .../io/trino/sql/parser/TestSqlParser.java | 14 +-- 9 files changed, 79 insertions(+), 298 deletions(-) delete mode 100644 core/trino-parser/src/main/java/io/trino/sql/tree/Cube.java delete mode 100644 core/trino-parser/src/main/java/io/trino/sql/tree/Rollup.java diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index d03830ce035d..419fd8da0a7f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -136,7 +136,6 @@ import io.trino.sql.tree.CreateTable; import io.trino.sql.tree.CreateTableAsSelect; import io.trino.sql.tree.CreateView; -import io.trino.sql.tree.Cube; import io.trino.sql.tree.Deallocate; import io.trino.sql.tree.Delete; import io.trino.sql.tree.Deny; @@ -203,7 +202,6 @@ import io.trino.sql.tree.ResetSession; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.Rollback; -import io.trino.sql.tree.Rollup; import io.trino.sql.tree.Row; import io.trino.sql.tree.RowPattern; import io.trino.sql.tree.SampledRelation; @@ -3975,18 +3973,18 @@ private void checkGroupingSetsCount(GroupBy node) if (element instanceof SimpleGroupBy) { product = 1; } - else if (element instanceof Cube) { - int exponent = element.getExpressions().size(); - if (exponent > 30) { - throw new ArithmeticException(); - } - product = 1 << exponent; - } - else if (element instanceof Rollup) { - product = element.getExpressions().size() + 1; - } - else if (element instanceof GroupingSets) { - product = ((GroupingSets) element).getSets().size(); + else if (element instanceof GroupingSets groupingSets) { + product = switch (groupingSets.getType()) { + case CUBE -> { + int exponent = ((GroupingSets) element).getSets().size(); + if (exponent > 30) { + throw new ArithmeticException(); + } + yield 1 << exponent; + } + case ROLLUP -> groupingSets.getSets().size() + 1; + case EXPLICIT -> groupingSets.getSets().size(); + }; } else { throw new UnsupportedOperationException("Unsupported grouping element type: " + element.getClass().getName()); @@ -4044,7 +4042,7 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, groupingExpressions.add(column); } } - else { + else if (groupingElement instanceof GroupingSets element) { for (Expression column : groupingElement.getExpressions()) { analyzeExpression(column, scope); if (!analysis.getColumnReferences().contains(NodeRef.of(column))) { @@ -4054,38 +4052,18 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, groupingExpressions.add(column); } - if (groupingElement instanceof Cube) { - List> cube = ((Cube) groupingElement).getSets().stream() - .map(set -> set.stream() - .map(NodeRef::of) - .map(analysis.getColumnReferenceFields()::get) - .map(ResolvedField::getFieldId) - .collect(toImmutableSet())) - .collect(toImmutableList()); - - cubes.add(cube); - } - else if (groupingElement instanceof Rollup) { - List> rollup = ((Rollup) groupingElement).getSets().stream() - .map(set -> set.stream() - .map(NodeRef::of) - .map(analysis.getColumnReferenceFields()::get) - .map(ResolvedField::getFieldId) - .collect(toImmutableSet())) - .collect(toImmutableList()); - - rollups.add(rollup); - } - else if (groupingElement instanceof GroupingSets) { - List> groupingSets = ((GroupingSets) groupingElement).getSets().stream() - .map(set -> set.stream() - .map(NodeRef::of) - .map(analysis.getColumnReferenceFields()::get) - .map(ResolvedField::getFieldId) - .collect(toImmutableSet())) - .collect(toImmutableList()); - - sets.add(groupingSets); + List> groupingSets = element.getSets().stream() + .map(set -> set.stream() + .map(NodeRef::of) + .map(analysis.getColumnReferenceFields()::get) + .map(ResolvedField::getFieldId) + .collect(toImmutableSet())) + .collect(toImmutableList()); + + switch (element.getType()) { + case CUBE -> cubes.add(groupingSets); + case ROLLUP -> rollups.add(groupingSets); + case EXPLICIT -> sets.add(groupingSets); } } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index e2738cfc04fe..7258e0200ec6 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -31,7 +31,6 @@ import io.trino.sql.tree.CharLiteral; import io.trino.sql.tree.CoalesceExpression; import io.trino.sql.tree.ComparisonExpression; -import io.trino.sql.tree.Cube; import io.trino.sql.tree.CurrentCatalog; import io.trino.sql.tree.CurrentPath; import io.trino.sql.tree.CurrentSchema; @@ -84,7 +83,6 @@ import io.trino.sql.tree.Parameter; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.QuantifiedComparisonExpression; -import io.trino.sql.tree.Rollup; import io.trino.sql.tree.Row; import io.trino.sql.tree.RowDataType; import io.trino.sql.tree.SearchedCaseExpression; @@ -1230,19 +1228,24 @@ static String formatGroupBy(List groupingElements) } } else if (groupingElement instanceof GroupingSets) { + String type; + switch (((GroupingSets) groupingElement).getType()) { + case EXPLICIT: + type = "GROUPING SETS"; + break; + case CUBE: + type = "CUBE"; + break; + case ROLLUP: + type = "ROLLUP"; + break; + default: + throw new UnsupportedOperationException(); + } + result = ((GroupingSets) groupingElement).getSets().stream() .map(ExpressionFormatter::formatGroupingSet) - .collect(joining(", ", "GROUPING SETS (", ")")); - } - else if (groupingElement instanceof Cube) { - result = ((Cube) groupingElement).getSets().stream() - .map(ExpressionFormatter::formatGroupingSet) - .collect(joining(", ", "CUBE (", ")")); - } - else if (groupingElement instanceof Rollup) { - result = ((Rollup) groupingElement).getSets().stream() - .map(ExpressionFormatter::formatGroupingSet) - .collect(joining(", ", "ROLLUP (", ")")); + .collect(joining(", ", type + " (", ")")); } return result; }) diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 32d2e5b030fb..670980e65dcf 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -49,7 +49,6 @@ import io.trino.sql.tree.CreateTable; import io.trino.sql.tree.CreateTableAsSelect; import io.trino.sql.tree.CreateView; -import io.trino.sql.tree.Cube; import io.trino.sql.tree.CurrentCatalog; import io.trino.sql.tree.CurrentPath; import io.trino.sql.tree.CurrentSchema; @@ -187,7 +186,6 @@ import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; -import io.trino.sql.tree.Rollup; import io.trino.sql.tree.Row; import io.trino.sql.tree.RowDataType; import io.trino.sql.tree.RowPattern; @@ -278,6 +276,9 @@ import static io.trino.sql.parser.SqlBaseParser.TIMESTAMP; import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_END; import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_START; +import static io.trino.sql.tree.GroupingSets.Type.CUBE; +import static io.trino.sql.tree.GroupingSets.Type.EXPLICIT; +import static io.trino.sql.tree.GroupingSets.Type.ROLLUP; import static io.trino.sql.tree.JsonExists.ErrorBehavior.ERROR; import static io.trino.sql.tree.JsonExists.ErrorBehavior.FALSE; import static io.trino.sql.tree.JsonExists.ErrorBehavior.TRUE; @@ -1145,7 +1146,7 @@ public Node visitSingleGroupingSet(SqlBaseParser.SingleGroupingSetContext contex @Override public Node visitRollup(SqlBaseParser.RollupContext context) { - return new Rollup(getLocation(context), context.groupingSet().stream() + return new GroupingSets(getLocation(context), ROLLUP, context.groupingSet().stream() .map(groupingSet -> visit(groupingSet.expression(), Expression.class)) .collect(toList())); } @@ -1153,7 +1154,7 @@ public Node visitRollup(SqlBaseParser.RollupContext context) @Override public Node visitCube(SqlBaseParser.CubeContext context) { - return new Cube(getLocation(context), context.groupingSet().stream() + return new GroupingSets(getLocation(context), CUBE, context.groupingSet().stream() .map(groupingSet -> visit(groupingSet.expression(), Expression.class)) .collect(toList())); } @@ -1161,7 +1162,7 @@ public Node visitCube(SqlBaseParser.CubeContext context) @Override public Node visitMultipleGroupingSets(SqlBaseParser.MultipleGroupingSetsContext context) { - return new GroupingSets(getLocation(context), context.groupingSet().stream() + return new GroupingSets(getLocation(context), EXPLICIT, context.groupingSet().stream() .map(groupingSet -> visit(groupingSet.expression(), Expression.class)) .collect(toList())); } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index eccac7a01c75..4346f02794f9 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -862,21 +862,11 @@ protected R visitGroupingElement(GroupingElement node, C context) return visitNode(node, context); } - protected R visitCube(Cube node, C context) - { - return visitGroupingElement(node, context); - } - protected R visitGroupingSets(GroupingSets node, C context) { return visitGroupingElement(node, context); } - protected R visitRollup(Rollup node, C context) - { - return visitGroupingElement(node, context); - } - protected R visitSimpleGroupBy(SimpleGroupBy node, C context) { return visitGroupingElement(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Cube.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Cube.java deleted file mode 100644 index 2e16503c78a0..000000000000 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Cube.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.tree; - -import com.google.common.collect.ImmutableList; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static java.util.Objects.requireNonNull; - -public final class Cube - extends GroupingElement -{ - private final List> sets; - - public Cube(List> sets) - { - this(Optional.empty(), sets); - } - - public Cube(NodeLocation location, List> sets) - { - this(Optional.of(location), sets); - } - - private Cube(Optional location, List> sets) - { - super(location); - this.sets = ImmutableList.copyOf(requireNonNull(sets, "sets is null")); - } - - public List> getSets() - { - return sets; - } - - @Override - public List getExpressions() - { - return sets.stream() - .flatMap(List::stream) - .collect(Collectors.toList()); - } - - @Override - protected R accept(AstVisitor visitor, C context) - { - return visitor.visitCube(this, context); - } - - @Override - public List getChildren() - { - return ImmutableList.of(); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Cube cube = (Cube) o; - return Objects.equals(sets, cube.sets); - } - - @Override - public int hashCode() - { - return Objects.hash(sets); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("sets", sets) - .toString(); - } -} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java index 9b2857ee4995..c92efc9ffd57 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java @@ -606,18 +606,6 @@ protected Void visitGroupBy(GroupBy node, C context) return null; } - @Override - protected Void visitCube(Cube node, C context) - { - return null; - } - - @Override - protected Void visitRollup(Rollup node, C context) - { - return null; - } - @Override protected Void visitSimpleGroupBy(SimpleGroupBy node, C context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/GroupingSets.java b/core/trino-parser/src/main/java/io/trino/sql/tree/GroupingSets.java index 823f28ee5341..13b2a44d0227 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/GroupingSets.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/GroupingSets.java @@ -28,26 +28,40 @@ public final class GroupingSets extends GroupingElement { + public enum Type + { + EXPLICIT, + ROLLUP, + CUBE + } + + private final Type type; private final List> sets; - public GroupingSets(List> groupingSets) + public GroupingSets(Type type, List> groupingSets) { - this(Optional.empty(), groupingSets); + this(Optional.empty(), type, groupingSets); } - public GroupingSets(NodeLocation location, List> sets) + public GroupingSets(NodeLocation location, Type type, List> sets) { - this(Optional.of(location), sets); + this(Optional.of(location), type, sets); } - private GroupingSets(Optional location, List> sets) + private GroupingSets(Optional location, Type type, List> sets) { super(location); + this.type = requireNonNull(type, "type is null"); requireNonNull(sets, "sets is null"); checkArgument(!sets.isEmpty(), "grouping sets cannot be empty"); this.sets = sets.stream().map(ImmutableList::copyOf).collect(toImmutableList()); } + public Type getType() + { + return type; + } + public List> getSets() { return sets; @@ -82,20 +96,21 @@ public boolean equals(Object o) if (o == null || getClass() != o.getClass()) { return false; } - GroupingSets groupingSets = (GroupingSets) o; - return Objects.equals(sets, groupingSets.sets); + GroupingSets that = (GroupingSets) o; + return type == that.type && sets.equals(that.sets); } @Override public int hashCode() { - return Objects.hash(sets); + return Objects.hash(type, sets); } @Override public String toString() { return toStringHelper(this) + .add("type", type) .add("sets", sets) .toString(); } @@ -108,6 +123,6 @@ public boolean shallowEquals(Node other) } GroupingSets that = (GroupingSets) other; - return Objects.equals(sets, that.sets); + return Objects.equals(sets, that.sets) && Objects.equals(type, that.type); } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Rollup.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Rollup.java deleted file mode 100644 index 812265fa84f3..000000000000 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Rollup.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.tree; - -import com.google.common.collect.ImmutableList; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static java.util.Objects.requireNonNull; - -public final class Rollup - extends GroupingElement -{ - private final List> sets; - - public Rollup(List> sets) - { - this(Optional.empty(), sets); - } - - public Rollup(NodeLocation location, List> sets) - { - this(Optional.of(location), sets); - } - - private Rollup(Optional location, List> sets) - { - super(location); - this.sets = ImmutableList.copyOf(requireNonNull(sets, "sets is null")); - } - - public List> getSets() - { - return sets; - } - - @Override - public List getExpressions() - { - return sets.stream() - .flatMap(List::stream) - .collect(Collectors.toList()); - } - - @Override - protected R accept(AstVisitor visitor, C context) - { - return visitor.visitRollup(this, context); - } - - @Override - public List getChildren() - { - return ImmutableList.of(); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Rollup rollup = (Rollup) o; - return Objects.equals(sets, rollup.sets); - } - - @Override - public int hashCode() - { - return Objects.hash(sets); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("sets", sets) - .toString(); - } -} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 73b96b5f49e5..fef934e78a75 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -45,7 +45,6 @@ import io.trino.sql.tree.CreateTable; import io.trino.sql.tree.CreateTableAsSelect; import io.trino.sql.tree.CreateView; -import io.trino.sql.tree.Cube; import io.trino.sql.tree.CurrentTime; import io.trino.sql.tree.Deallocate; import io.trino.sql.tree.DecimalLiteral; @@ -158,7 +157,6 @@ import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; -import io.trino.sql.tree.Rollup; import io.trino.sql.tree.Row; import io.trino.sql.tree.SearchedCaseExpression; import io.trino.sql.tree.Select; @@ -1601,6 +1599,7 @@ public void testSelectWithGroupBy() new Table(QualifiedName.of("table1")), Optional.empty(), Optional.of(new GroupBy(false, ImmutableList.of(new GroupingSets( + GroupingSets.Type.EXPLICIT, ImmutableList.of( ImmutableList.of(new Identifier("a"))))))), Optional.empty(), @@ -1619,6 +1618,7 @@ public void testSelectWithGroupBy() new Table(QualifiedName.of("table1")), Optional.empty(), Optional.of(new GroupBy(false, ImmutableList.of(new GroupingSets( + GroupingSets.Type.EXPLICIT, ImmutableList.of( ImmutableList.of(new Identifier("a")), ImmutableList.of(new Identifier("b"))))))), @@ -1634,12 +1634,13 @@ public void testSelectWithGroupBy() Optional.empty(), Optional.of(new GroupBy(false, ImmutableList.of( new GroupingSets( + GroupingSets.Type.EXPLICIT, ImmutableList.of( ImmutableList.of(new Identifier("a"), new Identifier("b")), ImmutableList.of(new Identifier("a")), ImmutableList.of())), - new Cube(ImmutableList.of(ImmutableList.of(new Identifier("c")))), - new Rollup(ImmutableList.of(ImmutableList.of(new Identifier("d"))))))), + new GroupingSets(GroupingSets.Type.CUBE, ImmutableList.of(ImmutableList.of(new Identifier("c")))), + new GroupingSets(GroupingSets.Type.ROLLUP, ImmutableList.of(ImmutableList.of(new Identifier("d"))))))), Optional.empty(), Optional.empty(), Optional.empty(), @@ -1652,12 +1653,13 @@ public void testSelectWithGroupBy() Optional.empty(), Optional.of(new GroupBy(true, ImmutableList.of( new GroupingSets( + GroupingSets.Type.EXPLICIT, ImmutableList.of( ImmutableList.of(new Identifier("a"), new Identifier("b")), ImmutableList.of(new Identifier("a")), ImmutableList.of())), - new Cube(ImmutableList.of(ImmutableList.of(new Identifier("c")))), - new Rollup(ImmutableList.of(ImmutableList.of(new Identifier("d"))))))), + new GroupingSets(GroupingSets.Type.CUBE, ImmutableList.of(ImmutableList.of(new Identifier("c")))), + new GroupingSets(GroupingSets.Type.ROLLUP, ImmutableList.of(ImmutableList.of(new Identifier("d"))))))), Optional.empty(), Optional.empty(), Optional.empty(),