Skip to content

Commit

Permalink
sql: add the string_agg aggregation function
Browse files Browse the repository at this point in the history
This function is similar to concat_agg, but it takes a delimiter as a second
argument. Previously, we were not able to handle aggregations with more than
one argument. To allow for this without getting into the messy world of
multi-column aggregators, all arguments after the first one in an aggregator
must be constant expressions.

This in turn required updating the aggregator functions to now also take
argument datums as a new argument.

For distsql, the arguments are stored as expressions that have already been
checked to ensure that they are constants.

It looks like concat_agg (and now string_agg) is not run in distsql yet, so
distsql support will be added next.

This work is primarily motivated by the need for greater ORM compatibility.

Closes cockroachdb#10495, cockroachdb#26737

Release note (sql change): Added the new aggregation function string_agg that
concatenates a collection of strings into a single string and separates them with the
passed-in delimiter.
  • Loading branch information
BramGruneir committed Aug 8, 2018
1 parent 23a6c6d commit 83f08d5
Show file tree
Hide file tree
Showing 14 changed files with 718 additions and 343 deletions.
4 changes: 4 additions & 0 deletions docs/generated/sql/aggregates.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@
</span></td></tr>
<tr><td><code>stddev(arg1: <a href="int.html">int</a>) &rarr; <a href="decimal.html">decimal</a></code></td><td><span class="funcdesc"><p>Calculates the standard deviation of the selected values.</p>
</span></td></tr>
<tr><td><code>string_agg(arg1: <a href="bytes.html">bytes</a>, arg2: <a href="bytes.html">bytes</a>) &rarr; <a href="bytes.html">bytes</a></code></td><td><span class="funcdesc"><p>Concatenates all selected values using the provided delimiter.</p>
</span></td></tr>
<tr><td><code>string_agg(arg1: <a href="string.html">string</a>, arg2: <a href="string.html">string</a>) &rarr; <a href="string.html">string</a></code></td><td><span class="funcdesc"><p>Concatenates all selected values using the provided delimiter.</p>
</span></td></tr>
<tr><td><code>sum(arg1: <a href="decimal.html">decimal</a>) &rarr; <a href="decimal.html">decimal</a></code></td><td><span class="funcdesc"><p>Calculates the sum of the selected values.</p>
</span></td></tr>
<tr><td><code>sum(arg1: <a href="float.html">float</a>) &rarr; <a href="float.html">float</a></code></td><td><span class="funcdesc"><p>Calculates the sum of the selected values.</p>
Expand Down
23 changes: 20 additions & 3 deletions pkg/sql/distsql_physical_planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,7 @@ func (dsp *DistSQLPlanner) addAggregators(
planCtx *planningCtx, p *physicalPlan, n *groupNode,
) error {
aggregations := make([]distsqlrun.AggregatorSpec_Aggregation, len(n.funcs))
aggregationsColumnTypes := make([][]sqlbase.ColumnType, len(n.funcs))
for i, fholder := range n.funcs {
// Convert the aggregate function to the enum value with the same string
// representation.
Expand All @@ -1184,6 +1185,19 @@ func (dsp *DistSQLPlanner) addAggregators(
col := uint32(p.planToStreamColMap[fholder.filterRenderIdx])
aggregations[i].FilterColIdx = &col
}
aggregations[i].Arguments = make([]distsqlrun.Expression, len(fholder.arguments))
aggregationsColumnTypes[i] = make([]sqlbase.ColumnType, len(fholder.arguments))
for j, argument := range fholder.arguments {
var err error
aggregations[i].Arguments[j], err = distsqlplan.MakeExpression(argument, planCtx.EvalContext(), nil)
if err != nil {
return err
}
aggregationsColumnTypes[i][j], err = sqlbase.DatumTypeToColumnType(argument.ResolvedType())
if err != nil {
return err
}
}
}

aggType := distsqlrun.AggregatorSpec_NON_SCALAR
Expand Down Expand Up @@ -1612,9 +1626,12 @@ func (dsp *DistSQLPlanner) addAggregators(

finalOutTypes := make([]sqlbase.ColumnType, len(aggregations))
for i, agg := range aggregations {
argTypes := make([]sqlbase.ColumnType, len(agg.ColIdx))
for i, c := range agg.ColIdx {
argTypes[i] = inputTypes[c]
argTypes := make([]sqlbase.ColumnType, len(agg.ColIdx)+len(agg.Arguments))
for j, c := range agg.ColIdx {
argTypes[j] = inputTypes[c]
}
for j, argumentColumnType := range aggregationsColumnTypes[i] {
argTypes[len(agg.ColIdx)+j] = argumentColumnType
}
var err error
_, finalOutTypes[i], err = distsqlrun.GetAggregateInfo(agg.Func, argTypes...)
Expand Down
58 changes: 41 additions & 17 deletions pkg/sql/distsqlrun/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"fmt"

"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sem/builtins"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/types"
Expand All @@ -39,7 +40,7 @@ import (
func GetAggregateInfo(
fn AggregatorSpec_Func, inputTypes ...sqlbase.ColumnType,
) (
aggregateConstructor func(*tree.EvalContext) tree.AggregateFunc,
aggregateConstructor func(*tree.EvalContext, tree.Datums) tree.AggregateFunc,
returnType sqlbase.ColumnType,
err error,
) {
Expand Down Expand Up @@ -71,8 +72,8 @@ func GetAggregateInfo(
}
if match {
// Found!
constructAgg := func(evalCtx *tree.EvalContext) tree.AggregateFunc {
return b.AggregateFunc(datumTypes, evalCtx)
constructAgg := func(evalCtx *tree.EvalContext, arguments tree.Datums) tree.AggregateFunc {
return b.AggregateFunc(datumTypes, evalCtx, arguments)
}

colTyp, err := sqlbase.DatumTypeToColumnType(b.FixedReturnType())
Expand All @@ -83,7 +84,7 @@ func GetAggregateInfo(
}
}
return nil, sqlbase.ColumnType{}, errors.Errorf(
"no builtin aggregate for %s on %v", fn, inputTypes,
"no builtin aggregate for %s on %+v", fn, inputTypes,
)
}

Expand Down Expand Up @@ -194,19 +195,40 @@ func (ag *aggregatorBase) init(
)
}
}
argTypes := make([]sqlbase.ColumnType, len(aggInfo.ColIdx))
for i, c := range aggInfo.ColIdx {
argTypes := make([]sqlbase.ColumnType, len(aggInfo.ColIdx)+len(aggInfo.Arguments))
for j, c := range aggInfo.ColIdx {
if c >= uint32(len(ag.inputTypes)) {
return errors.Errorf("ColIdx out of range (%d)", aggInfo.ColIdx)
}
argTypes[i] = ag.inputTypes[c]
argTypes[j] = ag.inputTypes[c]
}

arguments := make(tree.Datums, len(aggInfo.Arguments))
for j, argument := range aggInfo.Arguments {
expr, err := parser.ParseExpr(argument.Expr)
if err != nil {
return err
}
typedExpr, err := tree.TypeCheck(expr, &tree.SemaContext{}, types.Any)
if err != nil {
return errors.Wrap(err, expr.String())
}
argTypes[len(aggInfo.ColIdx)+j], err = sqlbase.DatumTypeToColumnType(typedExpr.ResolvedType())
if err != nil {
return errors.Wrap(err, expr.String())
}
arguments[j], err = typedExpr.Eval(ag.evalCtx)
if err != nil {
return errors.Wrap(err, expr.String())
}
}

aggConstructor, retType, err := GetAggregateInfo(aggInfo.Func, argTypes...)
if err != nil {
return err
}

ag.funcs[i] = ag.newAggregateFuncHolder(aggConstructor)
ag.funcs[i] = ag.newAggregateFuncHolder(aggConstructor, arguments)
if aggInfo.Distinct {
ag.funcs[i].seen = make(map[string]struct{})
}
Expand Down Expand Up @@ -815,21 +837,23 @@ func (ag *orderedAggregator) accumulateRow(row sqlbase.EncDatumRow) error {
}

type aggregateFuncHolder struct {
create func(*tree.EvalContext) tree.AggregateFunc
group *aggregatorBase
seen map[string]struct{}
arena *stringarena.Arena
create func(*tree.EvalContext, tree.Datums) tree.AggregateFunc
arguments tree.Datums
group *aggregatorBase
seen map[string]struct{}
arena *stringarena.Arena
}

const sizeOfAggregateFunc = int64(unsafe.Sizeof(tree.AggregateFunc(nil)))

func (ag *aggregatorBase) newAggregateFuncHolder(
create func(*tree.EvalContext) tree.AggregateFunc,
create func(*tree.EvalContext, tree.Datums) tree.AggregateFunc, arguments tree.Datums,
) *aggregateFuncHolder {
return &aggregateFuncHolder{
create: create,
group: ag,
arena: &ag.arena,
create: create,
group: ag,
arena: &ag.arena,
arguments: arguments,
}
}

Expand Down Expand Up @@ -886,7 +910,7 @@ func (ag *aggregatorBase) createAggregateFuncs() (aggregateFuncs, error) {
for i, f := range ag.funcs {
// TODO(radu): we should account for the size of impl (this needs to be done
// in each aggregate constructor).
bucket[i] = f.create(&ag.flowCtx.EvalCtx)
bucket[i] = f.create(&ag.flowCtx.EvalCtx, f.arguments)
}
return bucket, nil
}
Loading

0 comments on commit 83f08d5

Please sign in to comment.