diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 491ee7d7059..85002c8bc4d 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -372,6 +372,18 @@ func (v Value) IsDateTime() bool { return int(v.typ)&dt == dt } +// IsComparable returns true if the Value is null safe comparable without collation information. +func (v *Value) IsComparable() bool { + if v.typ == Null || IsNumber(v.typ) || IsBinary(v.typ) { + return true + } + switch v.typ { + case Timestamp, Date, Time, Datetime, Enum, Set, TypeJSON, Bit: + return true + } + return false +} + // MarshalJSON should only be used for testing. // It's not a complete implementation. func (v Value) MarshalJSON() ([]byte, error) { diff --git a/go/test/endtoend/vtgate/gen4/gen4_test.go b/go/test/endtoend/vtgate/gen4/gen4_test.go index 34b463fec10..8359044210c 100644 --- a/go/test/endtoend/vtgate/gen4/gen4_test.go +++ b/go/test/endtoend/vtgate/gen4/gen4_test.go @@ -86,6 +86,38 @@ func TestGroupBy(t *testing.T) { `[INT64(2) VARCHAR("B") VARCHAR("C") VARCHAR("abc")]]`) } +func TestDistinctAggregationFunc(t *testing.T) { + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() + + defer exec(t, conn, `delete from t2`) + + // insert some data. + checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) + + // count on primary vindex + assertMatches(t, conn, `select tcol1, count(distinct id) from t2 group by tcol1`, + `[[VARCHAR("A") INT64(3)] [VARCHAR("B") INT64(3)] [VARCHAR("C") INT64(2)]]`) + + // count on any column + assertMatches(t, conn, `select tcol1, count(distinct tcol2) from t2 group by tcol1`, + `[[VARCHAR("A") INT64(2)] [VARCHAR("B") INT64(2)] [VARCHAR("C") INT64(1)]]`) + + // sum of columns + assertMatches(t, conn, `select sum(id), sum(tcol1) from t2`, + `[[DECIMAL(36) FLOAT64(0)]]`) + + // sum on primary vindex + assertMatches(t, conn, `select tcol1, sum(distinct id) from t2 group by tcol1`, + `[[VARCHAR("A") DECIMAL(9)] [VARCHAR("B") DECIMAL(15)] [VARCHAR("C") DECIMAL(12)]]`) + + // sum on any column + assertMatches(t, conn, `select tcol1, sum(distinct tcol2) from t2 group by tcol1`, + `[[VARCHAR("A") DECIMAL(0)] [VARCHAR("B") DECIMAL(0)] [VARCHAR("C") DECIMAL(0)]]`) +} + func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { t.Helper() qr := checkedExec(t, conn, query) diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index acfc640f6e4..1876c9819d2 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -33,7 +33,7 @@ func (cached *AggregateParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(72) } // field Alias string size += int64(len(cached.Alias)) @@ -406,18 +406,18 @@ func (cached *OrderedAggregate) CachedSize(alloc bool) int64 { if alloc { size += int64(80) } - // field Aggregates []vitess.io/vitess/go/vt/vtgate/engine.AggregateParams + // field Aggregates []*vitess.io/vitess/go/vt/vtgate/engine.AggregateParams { - size += int64(cap(cached.Aggregates)) * int64(48) + size += int64(cap(cached.Aggregates)) * int64(8) for _, elem := range cached.Aggregates { - size += elem.CachedSize(false) + size += elem.CachedSize(true) } } - // field GroupByKeys []vitess.io/vitess/go/vt/vtgate/engine.GroupByParams + // field GroupByKeys []*vitess.io/vitess/go/vt/vtgate/engine.GroupByParams { - size += int64(cap(cached.GroupByKeys)) * int64(32) + size += int64(cap(cached.GroupByKeys)) * int64(8) for _, elem := range cached.GroupByKeys { - size += elem.CachedSize(false) + size += elem.CachedSize(true) } } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 747a77db040..e78db30f25e 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -45,11 +45,11 @@ type OrderedAggregate struct { PreProcess bool `json:",omitempty"` // Aggregates specifies the aggregation parameters for each // aggregation function: function opcode and input column number. - Aggregates []AggregateParams + Aggregates []*AggregateParams // GroupByKeys specifies the input values that must be used for // the aggregation key. - GroupByKeys []GroupByParams + GroupByKeys []*GroupByParams // TruncateColumnCount specifies the number of columns to return // in the final result. Rest of the columns are truncated @@ -80,25 +80,34 @@ func (gbp GroupByParams) String() string { type AggregateParams struct { Opcode AggregateOpcode Col int + + // These are used only for distinct opcodes. + KeyCol int + WCol int + WAssigned bool // Alias is set only for distinct opcodes. Alias string `json:",omitempty"` - Expr sqlparser.Expr + + Expr sqlparser.Expr } -func (ap AggregateParams) isDistinct() bool { +func (ap *AggregateParams) isDistinct() bool { return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct } -func (ap AggregateParams) preProcess() bool { +func (ap *AggregateParams) preProcess() bool { return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct || ap.Opcode == AggregateGtid } -func (ap AggregateParams) String() string { +func (ap *AggregateParams) String() string { + keyCol := strconv.Itoa(ap.Col) + if ap.WAssigned { + keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) + } if ap.Alias != "" { - return fmt.Sprintf("%s(%d) AS %s", ap.Opcode.String(), ap.Col, ap.Alias) + return fmt.Sprintf("%s(%s) AS %s", ap.Opcode.String(), keyCol, ap.Alias) } - - return fmt.Sprintf("%s(%d)", ap.Opcode.String(), ap.Col) + return fmt.Sprintf("%s(%s)", ap.Opcode.String(), keyCol) } // AggregateOpcode is the aggregation Opcode. @@ -306,6 +315,9 @@ func (oa *OrderedAggregate) convertFields(fields []*querypb.Field) []*querypb.Fi Name: aggr.Alias, Type: opcodeType[aggr.Opcode], } + if aggr.isDistinct() { + aggr.KeyCol = aggr.Col + } } return fields } @@ -318,15 +330,15 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. for _, aggr := range oa.Aggregates { switch aggr.Opcode { case AggregateCountDistinct: - curDistinct = row[aggr.Col] + curDistinct = findComparableCurrentDistinct(row, aggr) // Type is int64. Ok to call MakeTrusted. - if row[aggr.Col].IsNull() { + if row[aggr.KeyCol].IsNull() { newRow[aggr.Col] = countZero } else { newRow[aggr.Col] = countOne } case AggregateSumDistinct: - curDistinct = row[aggr.Col] + curDistinct = findComparableCurrentDistinct(row, aggr) var err error newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) if err != nil { @@ -347,6 +359,15 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. return newRow, curDistinct } +func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value { + curDistinct := row[aggr.KeyCol] + if aggr.WAssigned && !curDistinct.IsComparable() { + aggr.KeyCol = aggr.WCol + curDistinct = row[aggr.KeyCol] + } + return curDistinct +} + // GetFields is a Primitive function. func (oa *OrderedAggregate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { qr, err := oa.Input.GetFields(vcursor, bindVars) @@ -392,17 +413,17 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes result := sqltypes.CopyRow(row1) for _, aggr := range oa.Aggregates { if aggr.isDistinct() { - if row2[aggr.Col].IsNull() { + if row2[aggr.KeyCol].IsNull() { continue } - cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.Col]) + cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.KeyCol]) if err != nil { return nil, sqltypes.NULL, err } if cmp == 0 { continue } - curDistinct = row2[aggr.Col] + curDistinct = findComparableCurrentDistinct(row2, aggr) } var err error switch aggr.Opcode { @@ -473,11 +494,11 @@ func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) { } func aggregateParamsToString(in interface{}) string { - return in.(AggregateParams).String() + return in.(*AggregateParams).String() } func groupByParamsToString(i interface{}) string { - return i.(GroupByParams).String() + return i.(*GroupByParams).String() } func (oa *OrderedAggregate) description() PrimitiveDescription { diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index b471b94a580..a32f50b5e29 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -49,11 +49,11 @@ func TestOrderedAggregateExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -86,11 +86,11 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 2}}, + GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } @@ -128,11 +128,11 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -171,11 +171,11 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 2}}, + GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } @@ -307,7 +307,7 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCountDistinct, Col: 1, Alias: "count(distinct col2)", @@ -316,7 +316,7 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { Opcode: AggregateCount, Col: 2, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -383,7 +383,7 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCountDistinct, Col: 1, Alias: "count(distinct col2)", @@ -392,7 +392,7 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { Opcode: AggregateCount, Col: 2, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -471,7 +471,7 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateSumDistinct, Col: 1, Alias: "sum(distinct col2)", @@ -480,7 +480,7 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { Opcode: AggregateSum, Col: 2, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -520,12 +520,12 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateSumDistinct, Col: 1, Alias: "sum(distinct col2)", }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -556,11 +556,11 @@ func TestOrderedAggregateKeysFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -589,11 +589,11 @@ func TestOrderedAggregateMergeFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -629,7 +629,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) { func TestMerge(t *testing.T) { assert := assert.New(t) oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }, { @@ -716,12 +716,12 @@ func TestNoInputAndNoGroupingKeys(outer *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: test.opcode, Col: 0, Alias: test.name, }}, - GroupByKeys: []GroupByParams{}, + GroupByKeys: []*GroupByParams{}, Input: fp, } @@ -769,7 +769,7 @@ func TestOrderedAggregateExecuteGtid(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateGtid, Col: 1, Alias: "vgtid", @@ -790,3 +790,192 @@ func TestOrderedAggregateExecuteGtid(t *testing.T) { ) assert.Equal(t, wantResult, result) } + +func TestCountDistinctOnVarchar(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|weight_string(c2)", + "int64|varchar|varbinary", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "10|a|0x41", + "10|a|0x41", + "10|b|0x42", + "20|b|0x42", + )}, + } + + oa := &OrderedAggregate{ + PreProcess: true, + Aggregates: []*AggregateParams{{ + Opcode: AggregateCountDistinct, + Col: 1, + WCol: 2, + WAssigned: true, + Alias: "count(distinct c2)", + }}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, + Input: fp, + TruncateColumnCount: 2, + } + + want := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "c1|count(distinct c2)", + "int64|int64", + ), + `10|2`, + `20|1`, + ) + + qr, err := oa.Execute(nil, nil, false) + require.NoError(t, err) + assert.Equal(t, want, qr) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) + return nil + }) + require.NoError(t, err) + assert.Equal(t, want, results) +} + +func TestCountDistinctOnVarcharWithNulls(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|weight_string(c2)", + "int64|varchar|varbinary", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "null|null|null", + "null|a|0x41", + "null|b|0x42", + "10|null|null", + "10|null|null", + "10|a|0x41", + "10|a|0x41", + "10|b|0x42", + "20|null|null", + "20|b|0x42", + "30|null|null", + "30|null|null", + "30|null|null", + "30|null|null", + )}, + } + + oa := &OrderedAggregate{ + PreProcess: true, + Aggregates: []*AggregateParams{{ + Opcode: AggregateCountDistinct, + Col: 1, + WCol: 2, + WAssigned: true, + Alias: "count(distinct c2)", + }}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, + Input: fp, + TruncateColumnCount: 2, + } + + want := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "c1|count(distinct c2)", + "int64|int64", + ), + `null|2`, + `10|2`, + `20|1`, + `30|0`, + ) + + qr, err := oa.Execute(nil, nil, false) + require.NoError(t, err) + assert.Equal(t, want, qr) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) + return nil + }) + require.NoError(t, err) + assert.Equal(t, want, results) +} + +func TestSumDistinctOnVarcharWithNulls(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|weight_string(c2)", + "int64|varchar|varbinary", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "null|null|null", + "null|a|0x41", + "null|b|0x42", + "10|null|null", + "10|null|null", + "10|a|0x41", + "10|a|0x41", + "10|b|0x42", + "20|null|null", + "20|b|0x42", + "30|null|null", + "30|null|null", + "30|null|null", + "30|null|null", + )}, + } + + oa := &OrderedAggregate{ + PreProcess: true, + Aggregates: []*AggregateParams{{ + Opcode: AggregateSumDistinct, + Col: 1, + WCol: 2, + WAssigned: true, + Alias: "sum(distinct c2)", + }}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, + Input: fp, + TruncateColumnCount: 2, + } + + want := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "c1|sum(distinct c2)", + "int64|decimal", + ), + `null|0`, + `10|0`, + `20|0`, + `30|null`, + ) + + qr, err := oa.Execute(nil, nil, false) + require.NoError(t, err) + assert.Equal(t, want, qr) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) + return nil + }) + require.NoError(t, err) + assert.Equal(t, want, results) +} diff --git a/go/vt/vtgate/planbuilder/abstract/queryprojection.go b/go/vt/vtgate/planbuilder/abstract/queryprojection.go index 412f4fad69e..9311645a120 100644 --- a/go/vt/vtgate/planbuilder/abstract/queryprojection.go +++ b/go/vt/vtgate/planbuilder/abstract/queryprojection.go @@ -54,6 +54,10 @@ type ( GroupBy struct { Inner sqlparser.Expr WeightStrExpr sqlparser.Expr + + // This is to add the distinct function expression in grouping column for pushing down but not be to used as grouping key at VTGate level. + // Starts with 1 so that default (0) means unassigned. + DistinctAggrIndex int } ) @@ -69,7 +73,8 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) { return nil, semantics.Gen4NotSupportedF("%T in select list", selExp) } - if err := checkForInvalidAggregations(exp); err != nil { + err := checkForInvalidAggregations(exp) + if err != nil { return nil, err } col := SelectExpr{ @@ -137,9 +142,6 @@ func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error { if len(fExpr.Exprs) != 1 { return false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(fExpr)) } - if fExpr.Distinct { - return false, semantics.Gen4NotSupportedF("distinct aggregation") - } } return true, nil }, exp.Expr) diff --git a/go/vt/vtgate/planbuilder/grouping.go b/go/vt/vtgate/planbuilder/grouping.go index 1f17aa4ec54..1bf8ee06f62 100644 --- a/go/vt/vtgate/planbuilder/grouping.go +++ b/go/vt/vtgate/planbuilder/grouping.go @@ -77,7 +77,7 @@ func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.Grou default: return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: only simple references allowed") } - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: colNumber, WeightStringCol: -1}) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: colNumber, WeightStringCol: -1}) } // Append the distinct aggregate if any. if node.extraDistinct != nil { @@ -110,7 +110,7 @@ func planDistinct(input logicalPlan) (logicalPlan, error) { if rc.column.Origin() == node { return newDistinct(node), nil } - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: i, WeightStringCol: -1}) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: i, WeightStringCol: -1}) } newInput, err := planDistinct(node.input) if err != nil { diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 4afb4be4ad6..3aa99fc2f65 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -153,22 +153,42 @@ func (hp *horizonPlanning) planAggregations() error { } for _, e := range hp.qp.SelectExprs { - offset, _, err := pushProjection(e.Col, hp.plan, hp.semTable, true, false) + // push all expression if they are non-aggregating or the plan is not ordered aggregated plan. + if !e.Aggr || oa == nil { + _, _, err := pushProjection(e.Col, hp.plan, hp.semTable, true, false) + if err != nil { + return err + } + continue + } + + fExpr, isFunc := e.Col.Expr.(*sqlparser.FuncExpr) + if !isFunc { + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") + } + opcode := engine.SupportedAggregates[fExpr.Name.Lowered()] + handleDistinct, innerAliased, err := hp.needDistinctHandling(fExpr, opcode, oa.input) if err != nil { return err } - if e.Aggr && oa != nil { - fExpr, isFunc := e.Col.Expr.(*sqlparser.FuncExpr) - if !isFunc { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") - } - opcode := engine.SupportedAggregates[fExpr.Name.Lowered()] - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, engine.AggregateParams{ - Opcode: opcode, - Col: offset, - Expr: fExpr, - }) + + // Currently the OA engine primitive is able to handle only one distinct aggregation function. + // PreProcess being true tells that it is already handling it. + if oa.eaggr.PreProcess && handleDistinct { + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "multiple distinct aggregation function") } + + pushExpr, alias, opcode := hp.createPushExprAndAlias(e, handleDistinct, innerAliased, opcode, oa) + offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true) + if err != nil { + return err + } + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ + Opcode: opcode, + Col: offset, + Alias: alias, + Expr: fExpr, + }) } for _, groupExpr := range hp.qp.GroupByExprs { @@ -214,6 +234,44 @@ func (hp *horizonPlanning) planAggregations() error { return nil } +// createPushExprAndAlias creates the expression that should be pushed down to the leaves, +// and changes the opcode so it is a distinct one if needed +func (hp *horizonPlanning) createPushExprAndAlias( + expr abstract.SelectExpr, + handleDistinct bool, + innerAliased *sqlparser.AliasedExpr, + opcode engine.AggregateOpcode, + oa *orderedAggregate, +) (*sqlparser.AliasedExpr, string, engine.AggregateOpcode) { + pushExpr := expr.Col + var alias string + if handleDistinct { + pushExpr = innerAliased + + switch opcode { + case engine.AggregateCount: + opcode = engine.AggregateCountDistinct + case engine.AggregateSum: + opcode = engine.AggregateSumDistinct + } + if expr.Col.As.IsEmpty() { + alias = sqlparser.String(expr.Col.Expr) + } else { + alias = expr.Col.As.String() + } + + oa.eaggr.PreProcess = true + hp.haveToTruncate(true) + by := abstract.GroupBy{ + Inner: innerAliased.Expr, + WeightStrExpr: innerAliased.Expr, + DistinctAggrIndex: len(oa.eaggr.Aggregates) + 1, + } + hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, by) + } + return pushExpr, alias, opcode +} + func hasUniqueVindex(vschema ContextVSchema, semTable *semantics.SemTable, groupByExprs []abstract.GroupBy) bool { for _, groupByExpr := range groupByExprs { if exprHasUniqueVindex(vschema, semTable, groupByExpr.WeightStrExpr) { @@ -233,11 +291,18 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem _, _, added, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node, semTable) return added, err case *orderedAggregate: - keyCol, weightStringOffset, colAdded, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable) + keyCol, wsOffset, colAdded, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable) if err != nil { return false, err } - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: keyCol, WeightStringCol: weightStringOffset, Expr: groupExpr.WeightStrExpr}) + if groupExpr.DistinctAggrIndex == 0 { + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: keyCol, WeightStringCol: wsOffset, Expr: groupExpr.WeightStrExpr}) + } else { + if wsOffset != -1 { + node.eaggr.Aggregates[groupExpr.DistinctAggrIndex-1].WAssigned = true + node.eaggr.Aggregates[groupExpr.DistinctAggrIndex-1].WCol = wsOffset + } + } colAddedRecursively, err := planGroupByGen4(groupExpr, node.input, semTable) if err != nil { return false, err @@ -335,6 +400,9 @@ func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan log if err != nil { return 0, 0, false, err } + if weightStrExpr == nil { + return offset, -1, added, nil + } _, ok := expr.(*sqlparser.ColName) if !ok { return 0, 0, false, semantics.Gen4NotSupportedF("group by/order by non-column expression") @@ -508,7 +576,7 @@ func (hp *horizonPlanning) planDistinctOA(currPlan *orderedAggregate) error { for _, aggrParam := range currPlan.eaggr.Aggregates { if sqlparser.EqualsExpr(sExpr.Col.Expr, aggrParam.Expr) { found = true - eaggr.GroupByKeys = append(eaggr.GroupByKeys, engine.GroupByParams{KeyCol: aggrParam.Col, WeightStringCol: -1}) + eaggr.GroupByKeys = append(eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: aggrParam.Col, WeightStringCol: -1}) break } } @@ -531,7 +599,7 @@ func (hp *horizonPlanning) addDistinct() error { eaggr: eaggr, } for index, sExpr := range hp.qp.SelectExprs { - grpParam := engine.GroupByParams{KeyCol: index, WeightStringCol: -1} + grpParam := &engine.GroupByParams{KeyCol: index, WeightStringCol: -1} _, wOffset, added, err := wrapAndPushExpr(sExpr.Col.Expr, sExpr.Col.Expr, hp.plan, hp.semTable) if err != nil { return err @@ -552,3 +620,30 @@ func selectHasUniqueVindex(vschema ContextVSchema, semTable *semantics.SemTable, } return false } + +// needDistinctHandling returns true if oa needs to handle the distinct clause. +// If true, it will also return the aliased expression that needs to be pushed +// down into the underlying route. +func (hp *horizonPlanning) needDistinctHandling(funcExpr *sqlparser.FuncExpr, opcode engine.AggregateOpcode, input logicalPlan) (bool, *sqlparser.AliasedExpr, error) { + if !funcExpr.Distinct { + return false, nil, nil + } + if opcode != engine.AggregateCount && opcode != engine.AggregateSum { + return false, nil, nil + } + innerAliased, ok := funcExpr.Exprs[0].(*sqlparser.AliasedExpr) + if !ok { + return false, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "syntax error: %s", sqlparser.String(funcExpr)) + } + _, ok = input.(*route) + if !ok { + // Unreachable + return true, innerAliased, nil + } + if exprHasUniqueVindex(hp.vschema, hp.semTable, innerAliased.Expr) { + // if we can see a unique vindex on this table/column, + // we know the results will be unique, and we don't need to DISTINCTify them + return false, nil, nil + } + return true, innerAliased, nil +} diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index cf4bbed2e42..63b6cd33871 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -253,7 +253,7 @@ func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.Alias case engine.AggregateSum: opcode = engine.AggregateSumDistinct } - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, engine.AggregateParams{ + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ Opcode: opcode, Col: innerCol, Alias: alias, @@ -264,7 +264,7 @@ func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.Alias return nil, 0, err } pb.plan = newBuilder - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, engine.AggregateParams{ + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ Opcode: opcode, Col: innerCol, }) diff --git a/go/vt/vtgate/planbuilder/show.go b/go/vt/vtgate/planbuilder/show.go index 9608ce3c6fc..c0581df05ca 100644 --- a/go/vt/vtgate/planbuilder/show.go +++ b/go/vt/vtgate/planbuilder/show.go @@ -505,7 +505,7 @@ func buildShowVGtidPlan(show *sqlparser.ShowBasic, vschema ContextVSchema) (engi } return &engine.OrderedAggregate{ PreProcess: true, - Aggregates: []engine.AggregateParams{ + Aggregates: []*engine.AggregateParams{ { Opcode: engine.AggregateGtid, Col: 1, diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index 544c0bfaa5e..e226ded6152 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -866,6 +866,7 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont "Table": "`user`" } } +Gen4 plan same as above # count with distinct unique vindex "select col, count(distinct id) from user group by col" @@ -894,6 +895,31 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select col, count(distinct id) from user group by col", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(1)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col, count(distinct id), weight_string(col) from `user` where 1 != 1 group by col", + "OrderBy": "(0|2) ASC", + "Query": "select col, count(distinct id), weight_string(col) from `user` group by col order by col asc", + "Table": "`user`" + } + ] + } +} # count with distinct no unique vindex "select col1, count(distinct col2) from user group by col1" @@ -922,6 +948,31 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, count(distinct col2) from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(1|3) AS count(distinct col2)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } +} # count with distinct no unique vindex and no group by "select count(distinct col2) from user" @@ -949,6 +1000,30 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select count(distinct col2) from user", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(0|1) AS count(distinct col2)", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col2, weight_string(col2) from `user` where 1 != 1 group by col2", + "OrderBy": "(0|1) ASC", + "Query": "select col2, weight_string(col2) from `user` group by col2 order by col2 asc", + "Table": "`user`" + } + ] + } +} # count with distinct no unique vindex, count expression aliased "select col1, count(distinct col2) c2 from user group by col1" @@ -977,6 +1052,31 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, count(distinct col2) c2 from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(1|3) AS c2", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } +} # sum with distinct no unique vindex "select col1, sum(distinct col2) from user group by col1" @@ -1005,6 +1105,31 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, sum(distinct col2) from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_distinct(1|3) AS sum(distinct col2)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } +} # min with distinct no unique vindex. distinct is ignored. "select col1, min(distinct col2) from user group by col1" @@ -1033,6 +1158,31 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, min(distinct col2) from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "min(1)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, min(distinct col2), weight_string(col1) from `user` where 1 != 1 group by col1", + "OrderBy": "(0|2) ASC", + "Query": "select col1, min(distinct col2), weight_string(col1) from `user` group by col1 order by col1 asc", + "Table": "`user`" + } + ] + } +} # order by count distinct "select col1, count(distinct col2) k from user group by col1 order by k" @@ -1068,6 +1218,38 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, count(distinct col2) k from user group by col1 order by k", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "1 ASC", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(1|3) AS k", + "GroupBy": "(0|2)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } + ] + } +} # scatter aggregate group by aggregate function "select count(*) b from user group by b" diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 9eea5470aa0..545e2ae63f7 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -434,7 +434,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer sourceSelect := &sqlparser.Select{} targetSelect := &sqlparser.Select{} // aggregates contains the list if Aggregate functions, if any. - var aggregates []engine.AggregateParams + var aggregates []*engine.AggregateParams for _, selExpr := range sel.SelectExprs { switch selExpr := selExpr.(type) { case *sqlparser.StarExpr: @@ -463,7 +463,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer if expr, ok := selExpr.Expr.(*sqlparser.FuncExpr); ok { switch fname := expr.Name.Lowered(); fname { case "count", "sum": - aggregates = append(aggregates, engine.AggregateParams{ + aggregates = append(aggregates, &engine.AggregateParams{ Opcode: engine.SupportedAggregates[fname], Col: len(sourceSelect.SelectExprs) - 1, }) @@ -538,10 +538,10 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer return td, nil } -func pkColsToGroupByParams(pkCols []int) []engine.GroupByParams { - var res []engine.GroupByParams +func pkColsToGroupByParams(pkCols []int) []*engine.GroupByParams { + var res []*engine.GroupByParams for _, col := range pkCols { - res = append(res, engine.GroupByParams{KeyCol: col, WeightStringCol: -1}) + res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1}) } return res } diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index a0ed626a660..53736a04687 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -387,14 +387,14 @@ func TestVDiffPlanSuccess(t *testing.T) { pkCols: []int{0}, selectPks: []int{0}, sourcePrimitive: &engine.OrderedAggregate{ - Aggregates: []engine.AggregateParams{{ + Aggregates: []*engine.AggregateParams{{ Opcode: engine.AggregateCount, Col: 2, }, { Opcode: engine.AggregateSum, Col: 3, }}, - GroupByKeys: []engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1}}, + GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1}}, Input: newMergeSorter(nil, []compareColInfo{{0, 0, true}}), }, targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, 0, true}}),