Skip to content

Commit

Permalink
Merge pull request #8543 from planetscale/gen4-func-distinct
Browse files Browse the repository at this point in the history
Gen4: Count Distinct support
  • Loading branch information
harshit-gangal authored Jul 28, 2021
2 parents f3c7a2c + eeb4fc2 commit 0df3bab
Show file tree
Hide file tree
Showing 13 changed files with 613 additions and 80 deletions.
12 changes: 12 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,18 @@ func (v Value) IsDateTime() bool {
return int(v.typ)&dt == dt
}

// IsComparable reports whether the Value can be compared null-safely
// without collation information.
//
// It returns true for Null, for any type accepted by IsNumber or IsBinary,
// and for Timestamp, Date, Time, Datetime, Enum, Set, TypeJSON and Bit.
// Every other type returns false.
//
// The method uses a value receiver, like the other Value predicates, so it
// can be called on both Value and *Value, including non-addressable values.
func (v Value) IsComparable() bool {
	if v.typ == Null || IsNumber(v.typ) || IsBinary(v.typ) {
		return true
	}
	switch v.typ {
	case Timestamp, Date, Time, Datetime, Enum, Set, TypeJSON, Bit:
		return true
	}
	return false
}

// MarshalJSON should only be used for testing.
// It's not a complete implementation.
func (v Value) MarshalJSON() ([]byte, error) {
Expand Down
32 changes: 32 additions & 0 deletions go/test/endtoend/vtgate/gen4/gen4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,38 @@ func TestGroupBy(t *testing.T) {
`[INT64(2) VARCHAR("B") VARCHAR("C") VARCHAR("abc")]]`)
}

// TestDistinctAggregationFunc runs count and sum aggregations, with and
// without DISTINCT, against table t2 and compares each result with its
// expected textual form.
func TestDistinctAggregationFunc(t *testing.T) {
	conn, err := mysql.Connect(context.Background(), &vtParams)
	require.NoError(t, err)
	defer conn.Close()

	// Empty t2 again when the test returns (runs before the connection closes).
	defer exec(t, conn, `delete from t2`)

	// Seed t2 with eight rows.
	checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`)

	// The cases are executed in the order listed.
	cases := []struct {
		query    string
		expected string
	}{
		{
			// count(distinct id) per tcol1 group
			query:    `select tcol1, count(distinct id) from t2 group by tcol1`,
			expected: `[[VARCHAR("A") INT64(3)] [VARCHAR("B") INT64(3)] [VARCHAR("C") INT64(2)]]`,
		},
		{
			// count(distinct tcol2) per tcol1 group
			query:    `select tcol1, count(distinct tcol2) from t2 group by tcol1`,
			expected: `[[VARCHAR("A") INT64(2)] [VARCHAR("B") INT64(2)] [VARCHAR("C") INT64(1)]]`,
		},
		{
			// plain sums over the whole table
			query:    `select sum(id), sum(tcol1) from t2`,
			expected: `[[DECIMAL(36) FLOAT64(0)]]`,
		},
		{
			// sum(distinct id) per tcol1 group
			query:    `select tcol1, sum(distinct id) from t2 group by tcol1`,
			expected: `[[VARCHAR("A") DECIMAL(9)] [VARCHAR("B") DECIMAL(15)] [VARCHAR("C") DECIMAL(12)]]`,
		},
		{
			// sum(distinct tcol2) per tcol1 group
			query:    `select tcol1, sum(distinct tcol2) from t2 group by tcol1`,
			expected: `[[VARCHAR("A") DECIMAL(0)] [VARCHAR("B") DECIMAL(0)] [VARCHAR("C") DECIMAL(0)]]`,
		},
	}
	for _, tc := range cases {
		assertMatches(t, conn, tc.query, tc.expected)
	}
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
t.Helper()
qr := checkedExec(t, conn, query)
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 38 additions & 17 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ type OrderedAggregate struct {
PreProcess bool `json:",omitempty"`
// Aggregates specifies the aggregation parameters for each
// aggregation function: function opcode and input column number.
Aggregates []AggregateParams
Aggregates []*AggregateParams

// GroupByKeys specifies the input values that must be used for
// the aggregation key.
GroupByKeys []GroupByParams
GroupByKeys []*GroupByParams

// TruncateColumnCount specifies the number of columns to return
// in the final result. Rest of the columns are truncated
Expand Down Expand Up @@ -80,25 +80,34 @@ func (gbp GroupByParams) String() string {
// AggregateParams describes one aggregation: the function opcode and the
// input column number it applies to.
type AggregateParams struct {
Opcode AggregateOpcode
Col int

// These are used only for distinct opcodes.
// KeyCol is the column whose value is compared to detect a new distinct
// value; findComparableCurrentDistinct switches it to WCol when WAssigned
// is set and the KeyCol value is not comparable.
// NOTE(review): WCol presumably points at a weight-string column — confirm.
KeyCol int
WCol int
WAssigned bool
// Alias is set only for distinct opcodes.
Alias string `json:",omitempty"`
Expr sqlparser.Expr

Expr sqlparser.Expr
}

// isDistinct reports whether the opcode is AggregateCountDistinct or
// AggregateSumDistinct.
func (ap AggregateParams) isDistinct() bool {
func (ap *AggregateParams) isDistinct() bool {
return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct
}

// preProcess reports whether the opcode is AggregateCountDistinct,
// AggregateSumDistinct or AggregateGtid.
// NOTE(review): presumably these opcodes need their input rows converted
// before merging — confirm against the callers.
func (ap AggregateParams) preProcess() bool {
func (ap *AggregateParams) preProcess() bool {
return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct || ap.Opcode == AggregateGtid
}

// String renders the aggregation as "OPCODE(col)". When WAssigned is set the
// column is printed as "col|wcol", and when Alias is non-empty the result is
// suffixed with " AS alias".
func (ap AggregateParams) String() string {
func (ap *AggregateParams) String() string {
// Start from Col and append WCol only when one has been assigned.
keyCol := strconv.Itoa(ap.Col)
if ap.WAssigned {
keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol)
}
if ap.Alias != "" {
return fmt.Sprintf("%s(%d) AS %s", ap.Opcode.String(), ap.Col, ap.Alias)
return fmt.Sprintf("%s(%s) AS %s", ap.Opcode.String(), keyCol, ap.Alias)
}

return fmt.Sprintf("%s(%d)", ap.Opcode.String(), ap.Col)
return fmt.Sprintf("%s(%s)", ap.Opcode.String(), keyCol)
}

// AggregateOpcode is the aggregation Opcode.
Expand Down Expand Up @@ -306,6 +315,9 @@ func (oa *OrderedAggregate) convertFields(fields []*querypb.Field) []*querypb.Fi
Name: aggr.Alias,
Type: opcodeType[aggr.Opcode],
}
if aggr.isDistinct() {
aggr.KeyCol = aggr.Col
}
}
return fields
}
Expand All @@ -318,15 +330,15 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes.
for _, aggr := range oa.Aggregates {
switch aggr.Opcode {
case AggregateCountDistinct:
curDistinct = row[aggr.Col]
curDistinct = findComparableCurrentDistinct(row, aggr)
// Type is int64. Ok to call MakeTrusted.
if row[aggr.Col].IsNull() {
if row[aggr.KeyCol].IsNull() {
newRow[aggr.Col] = countZero
} else {
newRow[aggr.Col] = countOne
}
case AggregateSumDistinct:
curDistinct = row[aggr.Col]
curDistinct = findComparableCurrentDistinct(row, aggr)
var err error
newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode])
if err != nil {
Expand All @@ -347,6 +359,15 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes.
return newRow, curDistinct
}

// findComparableCurrentDistinct picks the value from row that serves as the
// current distinct key for aggr.
//
// The value at aggr.KeyCol is returned as-is unless a WCol has been assigned
// and that value is not comparable. In that case aggr.KeyCol is overwritten
// with aggr.WCol (so later lookups use it too) and the value at the new
// KeyCol is returned.
func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value {
	if val := row[aggr.KeyCol]; !aggr.WAssigned || val.IsComparable() {
		return val
	}
	aggr.KeyCol = aggr.WCol
	return row[aggr.KeyCol]
}

// GetFields is a Primitive function.
func (oa *OrderedAggregate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
qr, err := oa.Input.GetFields(vcursor, bindVars)
Expand Down Expand Up @@ -392,17 +413,17 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes
result := sqltypes.CopyRow(row1)
for _, aggr := range oa.Aggregates {
if aggr.isDistinct() {
if row2[aggr.Col].IsNull() {
if row2[aggr.KeyCol].IsNull() {
continue
}
cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.Col])
cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.KeyCol])
if err != nil {
return nil, sqltypes.NULL, err
}
if cmp == 0 {
continue
}
curDistinct = row2[aggr.Col]
curDistinct = findComparableCurrentDistinct(row2, aggr)
}
var err error
switch aggr.Opcode {
Expand Down Expand Up @@ -473,11 +494,11 @@ func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) {
}

// aggregateParamsToString type-asserts in to the aggregate params type and
// returns its String() form. The assertion is unchecked, so a value of any
// other dynamic type panics.
func aggregateParamsToString(in interface{}) string {
return in.(AggregateParams).String()
return in.(*AggregateParams).String()
}

// groupByParamsToString type-asserts i to the group-by params type and
// returns its String() form. The assertion is unchecked, so a value of any
// other dynamic type panics.
func groupByParamsToString(i interface{}) string {
return i.(GroupByParams).String()
return i.(*GroupByParams).String()
}

func (oa *OrderedAggregate) description() PrimitiveDescription {
Expand Down
Loading

0 comments on commit 0df3bab

Please sign in to comment.