Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gen4: Support multiple distinct aggregation functions in the query #8559

Merged
merged 3 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions go/test/endtoend/vtgate/gen4/gen4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func TestDistinctAggregationFunc(t *testing.T) {
// sum on any column
assertMatches(t, conn, `select tcol1, sum(distinct tcol2) from t2 group by tcol1`,
`[[VARCHAR("A") DECIMAL(0)] [VARCHAR("B") DECIMAL(0)] [VARCHAR("C") DECIMAL(0)]]`)

// insert more data to get values on sum
checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (9, 'AA', null),(10, 'AA', '4'),(11, 'AA', '4'),(12, null, '5'),(13, null, '6'),(14, 'BB', '10'),(15, 'BB', '20'),(16, 'BB', 'X')`)

// multi distinct
assertMatches(t, conn, `select tcol1, count(distinct tcol2), sum(distinct tcol2) from t2 group by tcol1`,
`[[NULL INT64(2) DECIMAL(11)] [VARCHAR("A") INT64(2) DECIMAL(0)] [VARCHAR("AA") INT64(1) DECIMAL(4)] [VARCHAR("B") INT64(2) DECIMAL(0)] [VARCHAR("BB") INT64(3) DECIMAL(30)] [VARCHAR("C") INT64(1) DECIMAL(0)]]`)
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
Expand Down
47 changes: 24 additions & 23 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
}
// This code is similar to the one in StreamExecute.
var current []sqltypes.Value
var curDistinct sqltypes.Value
var curDistincts []sqltypes.Value
for _, row := range result.Rows {
if current == nil {
current, curDistinct = oa.convertRow(row)
current, curDistincts = oa.convertRow(row)
continue
}

Expand All @@ -218,14 +218,14 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
}

if equal {
current, curDistinct, err = oa.merge(result.Fields, current, row, curDistinct)
current, curDistincts, err = oa.merge(result.Fields, current, row, curDistincts)
if err != nil {
return nil, err
}
continue
}
out.Rows = append(out.Rows, current)
current, curDistinct = oa.convertRow(row)
current, curDistincts = oa.convertRow(row)
}

if len(result.Rows) == 0 && len(oa.GroupByKeys) == 0 {
Expand All @@ -251,7 +251,7 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
// StreamExecute is a Primitive function.
func (oa *OrderedAggregate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var current []sqltypes.Value
var curDistinct sqltypes.Value
var curDistincts []sqltypes.Value
var fields []*querypb.Field

cb := func(qr *sqltypes.Result) error {
Expand All @@ -268,7 +268,7 @@ func (oa *OrderedAggregate) StreamExecute(vcursor VCursor, bindVars map[string]*
// This code is similar to the one in Execute.
for _, row := range qr.Rows {
if current == nil {
current, curDistinct = oa.convertRow(row)
current, curDistincts = oa.convertRow(row)
continue
}

Expand All @@ -278,7 +278,7 @@ func (oa *OrderedAggregate) StreamExecute(vcursor VCursor, bindVars map[string]*
}

if equal {
current, curDistinct, err = oa.merge(fields, current, row, curDistinct)
current, curDistincts, err = oa.merge(fields, current, row, curDistincts)
if err != nil {
return err
}
Expand All @@ -287,7 +287,7 @@ func (oa *OrderedAggregate) StreamExecute(vcursor VCursor, bindVars map[string]*
if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}); err != nil {
return err
}
current, curDistinct = oa.convertRow(row)
current, curDistincts = oa.convertRow(row)
}
return nil
})
Expand Down Expand Up @@ -322,23 +322,24 @@ func (oa *OrderedAggregate) convertFields(fields []*querypb.Field) []*querypb.Fi
return fields
}

func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes.Value, curDistinct sqltypes.Value) {
func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes.Value, curDistincts []sqltypes.Value) {
if !oa.PreProcess {
return row, sqltypes.NULL
return row, nil
}
newRow = append(newRow, row...)
for _, aggr := range oa.Aggregates {
curDistincts = make([]sqltypes.Value, len(oa.Aggregates))
for index, aggr := range oa.Aggregates {
switch aggr.Opcode {
case AggregateCountDistinct:
curDistinct = findComparableCurrentDistinct(row, aggr)
curDistincts[index] = findComparableCurrentDistinct(row, aggr)
// Type is int64. Ok to call MakeTrusted.
if row[aggr.KeyCol].IsNull() {
newRow[aggr.Col] = countZero
} else {
newRow[aggr.Col] = countOne
}
case AggregateSumDistinct:
curDistinct = findComparableCurrentDistinct(row, aggr)
curDistincts[index] = findComparableCurrentDistinct(row, aggr)
var err error
newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode])
if err != nil {
Expand All @@ -356,7 +357,7 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes.
newRow[aggr.Col] = val
}
}
return newRow, curDistinct
return newRow, curDistincts
}

func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value {
Expand Down Expand Up @@ -409,21 +410,21 @@ func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error)
return true, nil
}

func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes.Value, curDistinct sqltypes.Value) ([]sqltypes.Value, sqltypes.Value, error) {
func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes.Value, curDistincts []sqltypes.Value) ([]sqltypes.Value, []sqltypes.Value, error) {
result := sqltypes.CopyRow(row1)
for _, aggr := range oa.Aggregates {
for index, aggr := range oa.Aggregates {
if aggr.isDistinct() {
if row2[aggr.KeyCol].IsNull() {
continue
}
cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.KeyCol])
cmp, err := evalengine.NullsafeCompare(curDistincts[index], row2[aggr.KeyCol])
if err != nil {
return nil, sqltypes.NULL, err
return nil, nil, err
}
if cmp == 0 {
continue
}
curDistinct = findComparableCurrentDistinct(row2, aggr)
curDistincts[index] = findComparableCurrentDistinct(row2, aggr)
}
var err error
switch aggr.Opcode {
Expand All @@ -443,7 +444,7 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes
vgtid := &binlogdatapb.VGtid{}
err = proto.Unmarshal(row1[aggr.Col].ToBytes(), vgtid)
if err != nil {
return nil, sqltypes.NULL, err
return nil, nil, err
}
vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{
Keyspace: row2[aggr.Col-1].ToString(),
Expand All @@ -454,13 +455,13 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes
val, _ := sqltypes.NewValue(sqltypes.VarBinary, data)
result[aggr.Col] = val
default:
return nil, sqltypes.NULL, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode)
return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode)
}
if err != nil {
return nil, sqltypes.NULL, err
return nil, nil, err
}
}
return result, curDistinct, nil
return result, curDistincts, nil
}

// creates the empty row for the case when we are missing grouping keys and have empty input table
Expand Down
75 changes: 73 additions & 2 deletions go/vt/vtgate/engine/ordered_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,13 @@ func TestMerge(t *testing.T) {
"1|3|2.8|2|bc",
)

merged, _, err := oa.merge(fields, r.Rows[0], r.Rows[1], sqltypes.NULL)
merged, _, err := oa.merge(fields, r.Rows[0], r.Rows[1], nil)
assert.NoError(err)
want := sqltypes.MakeTestResult(fields, "1|5|6|2|bc").Rows[0]
assert.Equal(want, merged)

// swap and retry
merged, _, err = oa.merge(fields, r.Rows[1], r.Rows[0], sqltypes.NULL)
merged, _, err = oa.merge(fields, r.Rows[1], r.Rows[0], nil)
assert.NoError(err)
assert.Equal(want, merged)
}
Expand Down Expand Up @@ -979,3 +979,74 @@ func TestSumDistinctOnVarcharWithNulls(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, want, results)
}

func TestMultiDistinct(t *testing.T) {
fields := sqltypes.MakeTestFields(
"c1|c2|c3",
"int64|int64|int64",
)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"null|null|null",
"null|1|2",
"null|2|2",
"10|null|null",
"10|2|null",
"10|2|1",
"10|2|3",
"10|3|3",
"20|null|null",
"20|null|null",
"30|1|1",
"30|1|2",
"30|1|3",
"40|1|1",
"40|2|1",
"40|3|1",
)},
}

oa := &OrderedAggregate{
PreProcess: true,
Aggregates: []*AggregateParams{{
Opcode: AggregateCountDistinct,
Col: 1,
Alias: "count(distinct c2)",
}, {
Opcode: AggregateSumDistinct,
Col: 2,
Alias: "sum(distinct c3)",
}},
GroupByKeys: []*GroupByParams{{KeyCol: 0}},
Input: fp,
}

want := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"c1|count(distinct c2)|sum(distinct c3)",
"int64|int64|decimal",
),
`null|2|2`,
`10|2|4`,
`20|0|null`,
`30|1|6`,
`40|3|1`,
)

qr, err := oa.Execute(nil, nil, false)
require.NoError(t, err)
assert.Equal(t, want, qr)

fp.rewind()
results := &sqltypes.Result{}
err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error {
if qr.Fields != nil {
results.Fields = qr.Fields
}
results.Rows = append(results.Rows, qr.Rows...)
return nil
})
require.NoError(t, err)
assert.Equal(t, want, results)
}
8 changes: 1 addition & 7 deletions go/vt/vtgate/planbuilder/horizon_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,8 @@ func (hp *horizonPlanning) planAggregations() error {
return err
}

// Currently the OA engine primitive is able to handle only one distinct aggregation function.
// PreProcess being true tells that it is already handling it.
if oa.eaggr.PreProcess && handleDistinct {
return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "multiple distinct aggregation function")
}

pushExpr, alias, opcode := hp.createPushExprAndAlias(e, handleDistinct, innerAliased, opcode, oa)
offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true)
offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, false)
if err != nil {
return err
}
Expand Down
57 changes: 57 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2252,3 +2252,60 @@ Gen4 plan same as above
]
}
}

# Cannot have more than one aggr(distinct...
"select count(distinct a), count(distinct b) from user"
"unsupported: only one distinct aggregation allowed in a select: count(distinct b)"
{
"QueryType": "SELECT",
"Original": "select count(distinct a), count(distinct b) from user",
"Instructions": {
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "count_distinct(0|2) AS count(distinct a), count_distinct(1|3) AS count(distinct b)",
"ResultColumns": 2,
"Inputs": [
{
"OperatorType": "Route",
"Variant": "SelectScatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select a, b, weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, b",
"OrderBy": "(0|2) ASC, (1|3) ASC",
"Query": "select a, b, weight_string(a), weight_string(b) from `user` group by a, b order by a asc, b asc",
"Table": "`user`"
}
]
}
}

# multiple distinct functions with grouping.
"select col1, count(distinct col2), sum(distinct col2) from user group by col1"
"unsupported: only one distinct aggregation allowed in a select: sum(distinct col2)"
{
"QueryType": "SELECT",
"Original": "select col1, count(distinct col2), sum(distinct col2) from user group by col1",
"Instructions": {
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "count_distinct(1|4) AS count(distinct col2), sum_distinct(2|4) AS sum(distinct col2)",
"GroupBy": "(0|3)",
"ResultColumns": 3,
"Inputs": [
{
"OperatorType": "Route",
"Variant": "SelectScatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select col1, col2, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, col2",
"OrderBy": "(0|3) ASC, (1|4) ASC, (1|4) ASC",
"Query": "select col1, col2, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, col2 order by col1 asc, col2 asc, col2 asc",
"Table": "`user`"
}
]
}
}
4 changes: 0 additions & 4 deletions go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ Gen4 plan same as above
"unsupported: only one expression allowed inside aggregates: count(a, b)"
Gen4 error: aggregate functions take a single argument 'count(a, b)'

# Cannot have more than one aggr(distinct...
"select count(distinct a), count(distinct b) from user"
"unsupported: only one distinct aggregation allowed in a select: count(distinct b)"

# scatter aggregate symtab lookup error
"select id, b as id, count(*) from user order by id"
"ambiguous symbol reference: id"
Expand Down