Skip to content

Commit

Permalink
Remove unnecessary Cube and Rollup classes
Browse files Browse the repository at this point in the history
GroupingElement, Cube and Rollup contain the same information
and look very similar. Replace them with a single class and
and enum to indicate the type of grouping sets clause they
represent.
  • Loading branch information
martint committed Apr 17, 2023
1 parent 2486eca commit 4ab52e4
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 298 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.CreateTableAsSelect;
import io.trino.sql.tree.CreateView;
import io.trino.sql.tree.Cube;
import io.trino.sql.tree.Deallocate;
import io.trino.sql.tree.Delete;
import io.trino.sql.tree.Deny;
Expand Down Expand Up @@ -203,7 +202,6 @@
import io.trino.sql.tree.ResetSession;
import io.trino.sql.tree.Revoke;
import io.trino.sql.tree.Rollback;
import io.trino.sql.tree.Rollup;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowPattern;
import io.trino.sql.tree.SampledRelation;
Expand Down Expand Up @@ -3975,18 +3973,18 @@ private void checkGroupingSetsCount(GroupBy node)
if (element instanceof SimpleGroupBy) {
product = 1;
}
else if (element instanceof Cube) {
int exponent = element.getExpressions().size();
if (exponent > 30) {
throw new ArithmeticException();
}
product = 1 << exponent;
}
else if (element instanceof Rollup) {
product = element.getExpressions().size() + 1;
}
else if (element instanceof GroupingSets) {
product = ((GroupingSets) element).getSets().size();
else if (element instanceof GroupingSets groupingSets) {
product = switch (groupingSets.getType()) {
case CUBE -> {
int exponent = ((GroupingSets) element).getSets().size();
if (exponent > 30) {
throw new ArithmeticException();
}
yield 1 << exponent;
}
case ROLLUP -> groupingSets.getSets().size() + 1;
case EXPLICIT -> groupingSets.getSets().size();
};
}
else {
throw new UnsupportedOperationException("Unsupported grouping element type: " + element.getClass().getName());
Expand Down Expand Up @@ -4044,7 +4042,7 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope,
groupingExpressions.add(column);
}
}
else {
else if (groupingElement instanceof GroupingSets element) {
for (Expression column : groupingElement.getExpressions()) {
analyzeExpression(column, scope);
if (!analysis.getColumnReferences().contains(NodeRef.of(column))) {
Expand All @@ -4054,38 +4052,18 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope,
groupingExpressions.add(column);
}

if (groupingElement instanceof Cube) {
List<Set<FieldId>> cube = ((Cube) groupingElement).getSets().stream()
.map(set -> set.stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet()))
.collect(toImmutableList());

cubes.add(cube);
}
else if (groupingElement instanceof Rollup) {
List<Set<FieldId>> rollup = ((Rollup) groupingElement).getSets().stream()
.map(set -> set.stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet()))
.collect(toImmutableList());

rollups.add(rollup);
}
else if (groupingElement instanceof GroupingSets) {
List<Set<FieldId>> groupingSets = ((GroupingSets) groupingElement).getSets().stream()
.map(set -> set.stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet()))
.collect(toImmutableList());

sets.add(groupingSets);
List<Set<FieldId>> groupingSets = element.getSets().stream()
.map(set -> set.stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet()))
.collect(toImmutableList());

switch (element.getType()) {
case CUBE -> cubes.add(groupingSets);
case ROLLUP -> rollups.add(groupingSets);
case EXPLICIT -> sets.add(groupingSets);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.trino.sql.tree.CharLiteral;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Cube;
import io.trino.sql.tree.CurrentCatalog;
import io.trino.sql.tree.CurrentPath;
import io.trino.sql.tree.CurrentSchema;
Expand Down Expand Up @@ -84,7 +83,6 @@
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.QuantifiedComparisonExpression;
import io.trino.sql.tree.Rollup;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowDataType;
import io.trino.sql.tree.SearchedCaseExpression;
Expand Down Expand Up @@ -1230,19 +1228,24 @@ static String formatGroupBy(List<GroupingElement> groupingElements)
}
}
else if (groupingElement instanceof GroupingSets) {
String type;
switch (((GroupingSets) groupingElement).getType()) {
case EXPLICIT:
type = "GROUPING SETS";
break;
case CUBE:
type = "CUBE";
break;
case ROLLUP:
type = "ROLLUP";
break;
default:
throw new UnsupportedOperationException();
}

result = ((GroupingSets) groupingElement).getSets().stream()
.map(ExpressionFormatter::formatGroupingSet)
.collect(joining(", ", "GROUPING SETS (", ")"));
}
else if (groupingElement instanceof Cube) {
result = ((Cube) groupingElement).getSets().stream()
.map(ExpressionFormatter::formatGroupingSet)
.collect(joining(", ", "CUBE (", ")"));
}
else if (groupingElement instanceof Rollup) {
result = ((Rollup) groupingElement).getSets().stream()
.map(ExpressionFormatter::formatGroupingSet)
.collect(joining(", ", "ROLLUP (", ")"));
.collect(joining(", ", type + " (", ")"));
}
return result;
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.CreateTableAsSelect;
import io.trino.sql.tree.CreateView;
import io.trino.sql.tree.Cube;
import io.trino.sql.tree.CurrentCatalog;
import io.trino.sql.tree.CurrentPath;
import io.trino.sql.tree.CurrentSchema;
Expand Down Expand Up @@ -187,7 +186,6 @@
import io.trino.sql.tree.Revoke;
import io.trino.sql.tree.RevokeRoles;
import io.trino.sql.tree.Rollback;
import io.trino.sql.tree.Rollup;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowDataType;
import io.trino.sql.tree.RowPattern;
Expand Down Expand Up @@ -278,6 +276,9 @@
import static io.trino.sql.parser.SqlBaseParser.TIMESTAMP;
import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_END;
import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_START;
import static io.trino.sql.tree.GroupingSets.Type.CUBE;
import static io.trino.sql.tree.GroupingSets.Type.EXPLICIT;
import static io.trino.sql.tree.GroupingSets.Type.ROLLUP;
import static io.trino.sql.tree.JsonExists.ErrorBehavior.ERROR;
import static io.trino.sql.tree.JsonExists.ErrorBehavior.FALSE;
import static io.trino.sql.tree.JsonExists.ErrorBehavior.TRUE;
Expand Down Expand Up @@ -1145,23 +1146,23 @@ public Node visitSingleGroupingSet(SqlBaseParser.SingleGroupingSetContext contex
@Override
public Node visitRollup(SqlBaseParser.RollupContext context)
{
return new Rollup(getLocation(context), context.groupingSet().stream()
return new GroupingSets(getLocation(context), ROLLUP, context.groupingSet().stream()
.map(groupingSet -> visit(groupingSet.expression(), Expression.class))
.collect(toList()));
}

@Override
public Node visitCube(SqlBaseParser.CubeContext context)
{
return new Cube(getLocation(context), context.groupingSet().stream()
return new GroupingSets(getLocation(context), CUBE, context.groupingSet().stream()
.map(groupingSet -> visit(groupingSet.expression(), Expression.class))
.collect(toList()));
}

@Override
public Node visitMultipleGroupingSets(SqlBaseParser.MultipleGroupingSetsContext context)
{
return new GroupingSets(getLocation(context), context.groupingSet().stream()
return new GroupingSets(getLocation(context), EXPLICIT, context.groupingSet().stream()
.map(groupingSet -> visit(groupingSet.expression(), Expression.class))
.collect(toList()));
}
Expand Down
10 changes: 0 additions & 10 deletions core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -862,21 +862,11 @@ protected R visitGroupingElement(GroupingElement node, C context)
return visitNode(node, context);
}

protected R visitCube(Cube node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitGroupingSets(GroupingSets node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitRollup(Rollup node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitSimpleGroupBy(SimpleGroupBy node, C context)
{
return visitGroupingElement(node, context);
Expand Down
98 changes: 0 additions & 98 deletions core/trino-parser/src/main/java/io/trino/sql/tree/Cube.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -606,18 +606,6 @@ protected Void visitGroupBy(GroupBy node, C context)
return null;
}

@Override
protected Void visitCube(Cube node, C context)
{
return null;
}

@Override
protected Void visitRollup(Rollup node, C context)
{
return null;
}

@Override
protected Void visitSimpleGroupBy(SimpleGroupBy node, C context)
{
Expand Down
Loading

0 comments on commit 4ab52e4

Please sign in to comment.