Skip to content

Commit

Permalink
Merge pull request #8479 from planetscale/gen4-distinct
Browse files Browse the repository at this point in the history
Gen4: Distinct
  • Loading branch information
systay authored Jul 19, 2021
2 parents ed78318 + 548a08a commit 75bf629
Show file tree
Hide file tree
Showing 17 changed files with 688 additions and 297 deletions.
1 change: 1 addition & 0 deletions go/mysql/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ var stateToMysqlCode = map[vterrors.State]struct {
vterrors.WrongNumberOfColumnsInSelect: {num: ERWrongNumberOfColumnsInSelect, state: SSWrongNumberOfColumns},
vterrors.WrongTypeForVar: {num: ERWrongTypeForVar, state: SSClientError},
vterrors.WrongValueForVar: {num: ERWrongValueForVar, state: SSClientError},
vterrors.WrongFieldWithGroup: {num: ERWrongFieldWithGroup, state: SSClientError},
vterrors.ServerNotAvailable: {num: ERServerIsntAvailable, state: SSNetError},
vterrors.CantDoThisInTransaction: {num: ERCantDoThisDuringAnTransaction, state: SSCantDoThisDuringAnTransaction},
vterrors.RequiresPrimaryKey: {num: ERRequiresPrimaryKey, state: SSClientError},
Expand Down
21 changes: 13 additions & 8 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1379,20 +1379,25 @@ func ToString(exprs []TableExpr) string {
}

// ContainsAggregation returns true if the expression contains aggregation
func ContainsAggregation(e Expr) bool {
func ContainsAggregation(e SQLNode) bool {
hasAggregates := false
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *FuncExpr:
if node.IsAggregate() {
hasAggregates = true
return false, nil
}
case *GroupConcatExpr:
if IsAggregation(node) {
hasAggregates = true
return false, nil
}
return true, nil
}, e)
return hasAggregates
}

// IsAggregation reports whether node itself is an aggregation expression:
// either a *FuncExpr whose IsAggregate method returns true, or any
// *GroupConcatExpr. Every other node type yields false.
func IsAggregation(node SQLNode) bool {
	if fn, isFunc := node.(*FuncExpr); isFunc {
		return fn.IsAggregate()
	}
	_, isGroupConcat := node.(*GroupConcatExpr)
	return isGroupConcat
}
1 change: 1 addition & 0 deletions go/vt/vterrors/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
NonUniqTable
NonUpdateableTable
SyntaxError
WrongFieldWithGroup
WrongGroupField
WrongTypeForVar
WrongValueForVar
Expand Down
111 changes: 87 additions & 24 deletions go/vt/vtgate/planbuilder/abstract/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type (
// If you change the contents here, please update the toString() method
SelectExprs []SelectExpr
HasAggr bool
Distinct bool
GroupByExprs []GroupBy
OrderExprs []OrderBy
CanPushDownSorting bool
Expand All @@ -58,39 +59,28 @@ type (

// CreateQPFromSelect creates the QueryProjection for the input *sqlparser.Select
func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
qp := &QueryProjection{}
qp := &QueryProjection{
Distinct: sel.Distinct,
}

hasNonAggr := false
for _, selExp := range sel.SelectExprs {
exp, ok := selExp.(*sqlparser.AliasedExpr)
if !ok {
return nil, semantics.Gen4NotSupportedF("%T in select list", selExp)
}
fExpr, ok := exp.Expr.(*sqlparser.FuncExpr)
if ok && fExpr.IsAggregate() {
if len(fExpr.Exprs) != 1 {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(fExpr))
}
if fExpr.Distinct {
return nil, semantics.Gen4NotSupportedF("distinct aggregation")
}
qp.HasAggr = true
qp.SelectExprs = append(qp.SelectExprs, SelectExpr{
Col: exp,
Aggr: true,
})
continue
}

if err := checkForInvalidAggregations(exp); err != nil {
return nil, err
}
col := SelectExpr{
Col: exp,
}
if sqlparser.ContainsAggregation(exp.Expr) {
return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression")
col.Aggr = true
qp.HasAggr = true
}
hasNonAggr = true
qp.SelectExprs = append(qp.SelectExprs, SelectExpr{Col: exp})
}

if hasNonAggr && qp.HasAggr && sel.GroupBy == nil {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.MixOfGroupFuncAndFields, "Mixing of aggregation and non-aggregation columns is not allowed if there is no GROUP BY clause")
qp.SelectExprs = append(qp.SelectExprs, col)
}

for _, group := range sel.GroupBy {
Expand All @@ -104,6 +94,18 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
qp.GroupByExprs = append(qp.GroupByExprs, GroupBy{Inner: expr, WeightStrExpr: weightStrExpr})
}

if qp.HasAggr {
expr := qp.getNonAggrExprNotMatchingGroupByExprs()
// if we have aggregation functions, non aggregating columns and GROUP BY,
// the non-aggregating expressions must all be listed in the GROUP BY list
if expr != nil {
if len(qp.GroupByExprs) > 0 {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongFieldWithGroup, "Expression of SELECT list is not in GROUP BY clause and contains nonaggregated column '%s' which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_mode=only_full_group_by", sqlparser.String(expr))
}
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.MixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by", sqlparser.String(expr))
}
}

canPushDownSorting := true
for _, order := range sel.OrderBy {
expr, weightStrExpr, err := qp.getSimplifiedExpr(order.Expr, "order clause")
Expand All @@ -119,11 +121,47 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
})
canPushDownSorting = canPushDownSorting && !sqlparser.ContainsAggregation(weightStrExpr)
}
if qp.Distinct && !qp.HasAggr {
qp.GroupByExprs = nil
}
qp.CanPushDownSorting = canPushDownSorting

return qp, nil
}

// checkForInvalidAggregations walks the aliased expression and rejects
// aggregate function calls that this planner does not handle: those that do
// not take exactly one argument, and those that use DISTINCT.
// It returns whatever error sqlparser.Walk reports, or nil when none is raised.
func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error {
	visit := func(n sqlparser.SQLNode) (bool, error) {
		aggr, isFunc := n.(*sqlparser.FuncExpr)
		if !isFunc || !aggr.IsAggregate() {
			// not an aggregate function call, nothing to validate here
			return true, nil
		}
		switch {
		case len(aggr.Exprs) != 1:
			return false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(aggr))
		case aggr.Distinct:
			return false, semantics.Gen4NotSupportedF("distinct aggregation")
		}
		return true, nil
	}
	return sqlparser.Walk(visit, exp.Expr)
}

// getNonAggrExprNotMatchingGroupByExprs returns the first select expression
// that is not marked as an aggregate and that does not equal the WeightStrExpr
// of any GROUP BY expression. It returns nil when every non-aggregate select
// expression is matched by a GROUP BY expression.
func (qp *QueryProjection) getNonAggrExprNotMatchingGroupByExprs() sqlparser.Expr {
nextSelectExpr:
	for _, selectExpr := range qp.SelectExprs {
		if selectExpr.Aggr {
			continue
		}
		for _, grouping := range qp.GroupByExprs {
			if sqlparser.EqualsExpr(grouping.WeightStrExpr, selectExpr.Col.Expr) {
				// covered by the GROUP BY list, check the next select expression
				continue nextSelectExpr
			}
		}
		// no GROUP BY expression matched this non-aggregate expression
		return selectExpr.Col.Expr
	}
	return nil
}

// getSimplifiedExpr takes an expression used in ORDER BY or GROUP BY, which can reference both aliased columns and
// column offsets, and returns an expression that is simpler to evaluate
func (qp *QueryProjection) getSimplifiedExpr(e sqlparser.Expr, caller string) (expr sqlparser.Expr, weightStrExpr sqlparser.Expr, err error) {
Expand Down Expand Up @@ -169,11 +207,13 @@ func (qp *QueryProjection) toString() string {
Select []string
Grouping []string
OrderBy []string
Distinct bool
}
out := output{
Select: []string{},
Grouping: []string{},
OrderBy: []string{},
Distinct: qp.NeedsDistinct(),
}

for _, expr := range qp.SelectExprs {
Expand Down Expand Up @@ -204,3 +244,26 @@ func (qp *QueryProjection) toString() string {
func (qp *QueryProjection) NeedsAggregation() bool {
	// An aggregate in the select list is sufficient on its own.
	if qp.HasAggr {
		return true
	}
	// Otherwise aggregation is needed only when GROUP BY expressions exist.
	return len(qp.GroupByExprs) > 0
}

// onlyAggr reports whether the projection has aggregation and every select
// expression is marked as an aggregate. It returns false when HasAggr is unset
// or when any select expression is not an aggregate.
//
// It uses a pointer receiver, like the other QueryProjection methods, so the
// struct is not copied on every call.
func (qp *QueryProjection) onlyAggr() bool {
	if !qp.HasAggr {
		return false
	}
	for _, expr := range qp.SelectExprs {
		if !expr.Aggr {
			return false
		}
	}
	return true
}

// NeedsDistinct reports whether the query needs an explicit distinct.
// It is false when the projection is not marked Distinct, and also when the
// select list consists only of aggregates and there are no GROUP BY
// expressions; in every other case it is true.
func (qp *QueryProjection) NeedsDistinct() bool {
	switch {
	case !qp.Distinct:
		return false
	case qp.onlyAggr() && len(qp.GroupByExprs) == 0:
		return false
	default:
		return true
	}
}
42 changes: 33 additions & 9 deletions go/vt/vtgate/planbuilder/abstract/queryprojection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestQP(t *testing.T) {
},
{
sql: "select 1, count(1) from user",
expErr: "Mixing of aggregation and non-aggregation columns is not allowed if there is no GROUP BY clause",
expErr: "In aggregated query without GROUP BY, expression of SELECT list contains nonaggregated column '1'; this is incompatible with sql_mode=only_full_group_by",
},
{
sql: "select max(id) from user",
Expand All @@ -55,13 +55,9 @@ func TestQP(t *testing.T) {
sql: "select max(a, b) from user",
expErr: "aggregate functions take a single argument 'max(a, b)'",
},
{
sql: "select func(max(id)) from user",
expErr: "unsupported: in scatter query: complex aggregate expression",
},
{
sql: "select 1, count(1) from user order by 1",
expErr: "Mixing of aggregation and non-aggregation columns is not allowed if there is no GROUP BY clause",
expErr: "In aggregated query without GROUP BY, expression of SELECT list contains nonaggregated column '1'; this is incompatible with sql_mode=only_full_group_by",
},
{
sql: "select id from user order by col, id, 1",
Expand Down Expand Up @@ -133,7 +129,8 @@ func TestQPSimplifiedExpr(t *testing.T) {
"Grouping": [
"intcol"
],
"OrderBy": []
"OrderBy": [],
"Distinct": false
}`,
},
{
Expand All @@ -148,7 +145,8 @@ func TestQPSimplifiedExpr(t *testing.T) {
"OrderBy": [
"intcol asc",
"textcol asc"
]
],
"Distinct": false
}`,
},
{
Expand All @@ -167,7 +165,33 @@ func TestQPSimplifiedExpr(t *testing.T) {
],
"OrderBy": [
"textcol desc"
]
],
"Distinct": false
}`,
},
{
query: "select distinct col1, col2 from user group by col1, col2",
expected: `
{
"Select": [
"col1",
"col2"
],
"Grouping": [],
"OrderBy": [],
"Distinct": true
}`,
},
{
query: "select distinct count(*) from user",
expected: `
{
"Select": [
"aggr: count(*)"
],
"Grouping": [],
"OrderBy": [],
"Distinct": false
}`,
},
}
Expand Down
6 changes: 1 addition & 5 deletions go/vt/vtgate/planbuilder/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ import (

var _ logicalPlan = (*distinct)(nil)

// distinct is the logicalPlan for engine.Distinct.
// It embeds logicalPlanCommon and declares no fields of its own.
type distinct struct {
logicalPlanCommon
}
Expand Down
Loading

0 comments on commit 75bf629

Please sign in to comment.