diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 679b2e3c0fa..db496a01884 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -17,8 +17,6 @@ limitations under the License. package planbuilder import ( - "fmt" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -28,9 +26,8 @@ import ( type ( Horizon struct { - aggregateFuncs []AnalysedAliasedExpr - projections []AnalysedAliasedExpr - hasStar bool + projektioner []AnalysedAliasedExpr + hasStar, hasAggr bool } AnalysedExpr struct { pullouts []*pulloutSubquery @@ -41,9 +38,25 @@ type ( pullouts []*pulloutSubquery origin logicalPlan expr *sqlparser.AliasedExpr + aggr bool } ) +func (h *Horizon) HasAggregation() bool { + return h.hasAggr +} +func (h *Horizon) AddProjection(e AnalysedAliasedExpr) { + if isAggregateExpression(e.expr.Expr) { + h.hasAggr = true + } + + h.projektioner = append(h.projektioner, e) +} + +// createColumnsFor creates column expressions to replace `*` expressions. If a single table is provided, +// the expansion will create columns without any column qualifiers, but if multiple tables are listed in +// the FROM clause, the column expressions will be of the type `tabl.col as col`, so the query doesn't +// accidentally become ambiguous func createColumnsFor(tables []*table) sqlparser.SelectExprs { result := sqlparser.SelectExprs{} singleTable := false @@ -160,14 +173,9 @@ func (pb *primitiveBuilder) analyseSelectExpr(sel *sqlparser.Select) (*Horizon, origin: origin, expr: node, } - if isAggregateExpression(expr) { - result.aggregateFuncs = append(result.aggregateFuncs, analysedExpr) - } else { - result.projections = append(result.projections, analysedExpr) - } - + result.AddProjection(analysedExpr) default: - return nil, fmt.Errorf("BUG: unexpected select expression type: %T", node) + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: unexpected select expression type: %T", node) } } return result, nil @@ -184,50 +192,49 @@ func isAggregateExpression(expr sqlparser.Expr) bool { func (pb *primitiveBuilder) planHorizon(sel *sqlparser.Select, horizon *Horizon) error { rb, isRoute := pb.plan.(*route) - if isRoute && len(horizon.aggregateFuncs) == 0 { - // since we can push down all of the aggregation to the route, + if isRoute && + (!horizon.HasAggregation() || rb.isSingleShard()) { // we don't need to do anything else here rb.Select = sel return nil } - resultColumns := make([]*resultColumn, 0, len(horizon.projections)+len(horizon.aggregateFuncs)) - for _, projection := range horizon.projections { - rc, _, err := pb.pushProjection(pb.plan, projection.expr, projection.origin) - if err != nil { - return err + var aggrPlan *orderedAggregate + if horizon.HasAggregation() { + eaggr := &engine.OrderedAggregate{} + aggrPlan = &orderedAggregate{ + resultsBuilder: newResultsBuilder(rb, eaggr), + eaggr: eaggr, } - resultColumns = append(resultColumns, rc) + pb.plan = aggrPlan } - if len(horizon.aggregateFuncs) > 0 { - newColumns, err := pb.planAggregation(rb, horizon) - if err != nil { - return err + + resultColumns := make([]*resultColumn, 0, len(horizon.projektioner)) + for _, projection := range horizon.projektioner { + expr := projection.expr.Expr + if isAggregateExpression(expr) { + rc, _, err := aggrPlan.pushAggr2(pb, projection.expr, projection.origin) + if err != nil { + return err + } + resultColumns = append(resultColumns, rc) + } else { + // Ensure that there are no aggregates in the expression. + if nodeHasAggregates(expr) { + return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") + } + + rc, _, err := pb.pushProjection(pb.plan, projection.expr, projection.origin) + if err != nil { + return err + } + resultColumns = append(resultColumns, rc) } - resultColumns = append(resultColumns, newColumns...) } pb.st.SetResultColumns(resultColumns) return nil } -func (pb *primitiveBuilder) planAggregation(rb *route, horizon *Horizon) ([]*resultColumn, error) { - var resultColumns []*resultColumn - eaggr := &engine.OrderedAggregate{} - newPlan := &orderedAggregate{ - resultsBuilder: newResultsBuilder(rb, eaggr), - eaggr: eaggr, - } - pb.plan = newPlan - for _, aggrFunc := range horizon.aggregateFuncs { - rc, _, err := newPlan.pushAggr2(pb, aggrFunc.expr, aggrFunc.origin) - if err != nil { - return nil, err - } - resultColumns = append(resultColumns, rc) - } - return resultColumns, nil -} - func (pb *primitiveBuilder) pushProjection(in logicalPlan, expr *sqlparser.AliasedExpr, origin logicalPlan) (*resultColumn, int, error) { switch node := in.(type) { case *join: diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index 62e472678b0..e500b9cf2e1 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" "strconv" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" @@ -238,7 +240,7 @@ func (oa *orderedAggregate) pushAggr2(pb *primitiveBuilder, expr *sqlparser.Alia funcExpr := expr.Expr.(*sqlparser.FuncExpr) opcode := engine.SupportedAggregates[funcExpr.Name.Lowered()] if len(funcExpr.Exprs) != 1 { - return nil, 0, fmt.Errorf("unsupported: only one expression allowed inside aggregates: %s", sqlparser.String(funcExpr)) + return nil, 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: only one expression allowed inside aggregates: %s", sqlparser.String(funcExpr)) } handleDistinct, _, err := oa.needDistinctHandling(pb, funcExpr, opcode) if err != nil {