diff --git a/db/tests/query/inline_array/with_count_test.go b/db/tests/query/inline_array/with_count_test.go index 452efcc1e4..d658c63b73 100644 --- a/db/tests/query/inline_array/with_count_test.go +++ b/db/tests/query/inline_array/with_count_test.go @@ -71,7 +71,7 @@ func TestQueryInlineIntegerArrayWithsWithCountAndEmptyArray(t *testing.T) { func TestQueryInlineIntegerArrayWithsWithCountAndPopulatedArray(t *testing.T) { test := testUtils.QueryTestCase{ - Description: "Simple inline array with no filter, count of empty integer array", + Description: "Simple inline array with no filter, count of integer array", Query: `query { users { Name diff --git a/db/tests/query/inline_array/with_sum_test.go b/db/tests/query/inline_array/with_sum_test.go new file mode 100644 index 0000000000..0567405e2d --- /dev/null +++ b/db/tests/query/inline_array/with_sum_test.go @@ -0,0 +1,178 @@ +// Copyright 2020 Source Inc. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +package inline_array + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/db/tests" +) + +func TestQueryInlineIntegerArrayWithsWithSumAndNullArray(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple inline array with no filter, sum of nil integer array", + Query: `query { + users { + Name + _sum(field: {FavouriteIntegers: {}}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "FavouriteIntegers": null + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": int64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQueryInlineIntegerArrayWithsWithSumAndEmptyArray(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple inline array with no filter, sum of empty integer array", + Query: `query { + users { + Name + _sum(field: {FavouriteIntegers: {}}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "FavouriteIntegers": [] + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": int64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQueryInlineIntegerArrayWithsWithSumAndPopulatedArray(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple inline array with no filter, sum of integer array", + Query: `query { + users { + Name + _sum(field: {FavouriteIntegers: {}}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "Shahzad", + "FavouriteIntegers": [-1, 2, -1, 1, 0] + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "Shahzad", + "_sum": int64(1), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQueryInlineFloatArrayWithsWithSumAndNullArray(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple inline array with no filter, sum of nil float array", + Query: `query { + users { + Name + _sum(field: {FavouriteFloats: {}}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "FavouriteFloats": null + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": float64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQueryInlineFloatArrayWithsWithSumAndEmptyArray(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple inline array with no filter, sum of empty float array", + Query: `query { + users { + Name + _sum(field: {FavouriteFloats: {}}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "FavouriteFloats": [] + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": float64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQueryInlineFloatArrayWithsWithSumAndPopulatedArray(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple inline array with no filter, sum of float array", + Query: `query { + users { + Name + _sum(field: {FavouriteFloats: {}}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "FavouriteFloats": [3.1425, 0.00000000001, 10] + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": float64(13.14250000001), + }, + }, + } + + executeTestCase(t, test) +} diff --git a/db/tests/query/simple/utils.go b/db/tests/query/simple/utils.go index 34de538243..99015c9e9b 100644 --- a/db/tests/query/simple/utils.go +++ b/db/tests/query/simple/utils.go @@ -19,6 +19,7 @@ var userCollectionGQLSchema = (` type users { Name: String Age: Int + HeightM: Float Verified: Boolean } `) diff --git a/db/tests/query/simple/with_group_sum_test.go b/db/tests/query/simple/with_group_sum_test.go new file mode 100644 index 0000000000..4635bed634 --- /dev/null +++ b/db/tests/query/simple/with_group_sum_test.go @@ -0,0 +1,201 @@ +// Copyright 2020 Source Inc. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +package simple + +// TODO!!!!! once scalar are merged, this should be capable of summing int/float arrays - likely needs some tweaks in generator.go and query.go + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/db/tests" +) + +func TestQuerySimpleWithGroupByStringWithoutRenderedGroupAndSumOfUndefined(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query with sum on unspecified field", + Query: `query { + users (groupBy: [Name]) { + Name + _sum + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "Age": 32 + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": int64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithGroupByStringWithoutRenderedGroupAndChildIntegerSum(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query with group by string, sum on non-rendered group integer value", + Query: `query { + users(groupBy: [Name]) { + Name + _sum(field: {_group: Age}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "Age": 32 + }`), + (`{ + "Name": "John", + "Age": 38 + }`), + // It is important to test negative values here, due to the auto-typing of numbers + (`{ + "Name": "Alice", + "Age": -19 + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": int64(70), + }, + { + "Name": "Alice", + "_sum": int64(-19), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithGroupByStringWithoutRenderedGroupAndChildNilSum(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query with group by string, sum on non-rendered group nil and integer values", + Query: `query { + users(groupBy: [Name]) { + Name + _sum(field: {_group: Age}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "Age": 32 + }`), + // Age is undefined here + (`{ + "Name": "John" + }`), + (`{ + "Name": "Alice", + "Age": 19 + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "Alice", + "_sum": int64(19), + }, + { + "Name": "John", + "_sum": int64(32), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithGroupByStringWithoutRenderedGroupAndChildEmptyFloatSum(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query with group by string, sum on non-rendered group float (default) value", + Query: `query { + users(groupBy: [Name]) { + Name + _sum(field: {_group: HeightM}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "HeightM": 1.82 + }`), + (`{ + "Name": "John", + "HeightM": 1.89 + }`), + (`{ + "Name": "Alice" + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": float64(3.71), + }, + { + "Name": "Alice", + "_sum": float64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithGroupByStringWithoutRenderedGroupAndChildFloatSum(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query with group by string, sum on non-rendered group float value", + Query: `query { + users(groupBy: [Name]) { + Name + _sum(field: {_group: HeightM}) + } + }`, + Docs: map[int][]string{ + 0: { + (`{ + "Name": "John", + "HeightM": 1.82 + }`), + (`{ + "Name": "John", + "HeightM": 1.89 + }`), + (`{ + "Name": "Alice", + "HeightM": 2.04 + }`)}, + }, + Results: []map[string]interface{}{ + { + "Name": "John", + "_sum": float64(3.71), + }, + { + "Name": "Alice", + "_sum": float64(2.04), + }, + }, + } + + executeTestCase(t, test) +} diff --git a/query/graphql/parser/commit.go b/query/graphql/parser/commit.go index de35e50d88..bfafb23b9a 100644 --- a/query/graphql/parser/commit.go +++ b/query/graphql/parser/commit.go @@ -45,7 +45,6 @@ type CommitSelect struct { Limit *Limit OrderBy *OrderBy - Counts []Count Fields []Selection @@ -72,17 +71,12 @@ func (c CommitSelect) GetSelections() []Selection { return c.Fields } -func (s *CommitSelect) AddCount(count Count) { - s.Counts = append(s.Counts, count) -} - func (c CommitSelect) ToSelect() *Select { return &Select{ Name: c.Name, Alias: c.Alias, Limit: c.Limit, OrderBy: c.OrderBy, - Counts: c.Counts, Statement: c.Statement, Fields: c.Fields, Root: CommitSelection, @@ -126,7 +120,5 @@ func parseCommitSelect(field *ast.Field) (*CommitSelect, error) { var err error commit.Fields, err = parseSelectFields(commit.GetRoot(), field.SelectionSet) - parseCounts(commit) - return commit, err } diff --git a/query/graphql/parser/query.go b/query/graphql/parser/query.go index 54369479a6..7c01ebec43 100644 --- a/query/graphql/parser/query.go +++ b/query/graphql/parser/query.go @@ -28,6 +28,7 @@ const ( GroupFieldName = "_group" DocKeyFieldName = "_key" CountFieldName = "_count" + SumFieldName = "_sum" HiddenFieldName = "_hidden" ) @@ -41,10 +42,16 @@ var ReservedFields = map[string]bool{ VersionFieldName: true, GroupFieldName: true, CountFieldName: true, + SumFieldName: true, HiddenFieldName: true, DocKeyFieldName: true, } +var aggregates = map[string]struct{}{ + CountFieldName: {}, + SumFieldName: {}, +} + type Query struct { Queries []*OperationDefinition Mutations []*OperationDefinition @@ -78,11 +85,6 @@ type Selection interface { GetRoot() SelectionType } -type baseSelect interface { - Selection - AddCount(count Count) -} - // Select is a complex Field with strong typing // It used for sub types in a query. Includes // fields, and query arguments like filters, @@ -102,7 +104,6 @@ type Select struct { Limit *Limit OrderBy *OrderBy GroupBy *GroupBy - Counts []Count Fields []Selection @@ -130,10 +131,6 @@ func (s Select) GetAlias() string { return s.Alias } -func (s *Select) AddCount(count Count) { - s.Counts = append(s.Counts, count) -} - // Field implements Selection type Field struct { Name string @@ -170,11 +167,6 @@ type GroupBy struct { Fields []string } -type Count struct { - Name string - Field string -} - type SortDirection string const ( @@ -387,11 +379,6 @@ func parseSelect(rootType SelectionType, field *ast.Field) (*Select, error) { return nil, err } - err = parseCounts(slct) - if err != nil { - return nil, err - } - return slct, err } @@ -402,7 +389,10 @@ func parseSelectFields(root SelectionType, fields *ast.SelectionSet) ([]Selectio switch node := selection.(type) { case *ast.Field: if node.SelectionSet == nil { // regular field - f := parseField(root, node) + f, err := parseField(i, root, node) + if err != nil { + return nil, err + } selections[i] = f } else { // sub type with extra fields subroot := root @@ -424,16 +414,31 @@ func parseSelectFields(root SelectionType, fields *ast.SelectionSet) ([]Selectio // parseField simply parses the Name/Alias // into a Field type -func parseField(root SelectionType, field *ast.Field) *Field { +func parseField(i int, root SelectionType, field *ast.Field) (*Field, error) { + var name string + var alias string + + if _, isAggregate := aggregates[field.Name.Value]; isAggregate { + name = fmt.Sprintf("_agg%v", i) + if field.Alias == nil { + alias = field.Name.Value + } else { + alias = field.Alias.Value + } + } else { + name = field.Name.Value + if field.Alias != nil { + alias = field.Alias.Value + } + } + f := &Field{ Root: root, - Name: field.Name.Value, + Name: name, Statement: field, + Alias: alias, } - if field.Alias != nil { - f.Alias = field.Alias.Value - } - return f + return f, nil } func parseAPIQuery(field *ast.Field) (Selection, error) { @@ -445,31 +450,29 @@ func parseAPIQuery(field *ast.Field) (Selection, error) { } } -// Parses requested _count(s), creating a virtual name (alias) if an alias is not provided to allow for multiple _count -// fields. Strongly consider refactoring this as more aggregates get added. -func parseCounts(slct baseSelect) error { - for i, field := range slct.GetSelections() { - if field.GetName() == CountFieldName { - virtualName := fmt.Sprintf("count%v", i) - f := field.(*Field) - if f.Alias == "" { - f.Alias = f.Name - } - f.Name = virtualName - var fieldName string - fieldStatement, statementIsField := field.GetStatement().(*ast.Field) - if !statementIsField { - return fmt.Errorf("Unexpected error: could not cast field statement to field.") - } +// Returns the source of the aggregate as requested by the consumer +func (field Field) GetAggregateSource() ([]string, error) { + var path []string - if len(fieldStatement.Arguments) == 0 { - fieldName = "" + if len(field.Statement.Arguments) == 0 { + path = []string{} + } else { + switch arguementValue := field.Statement.Arguments[0].Value.GetValue().(type) { + case string: + path = []string{arguementValue} + case []*ast.ObjectField: + if len(arguementValue) == 0 { + return []string{}, fmt.Errorf("Unexpected error: aggregate field contained no child field selector") + } + innerPath := arguementValue[0].Value.GetValue() + if innerPathStringValue, isString := innerPath.(string); isString { + path = []string{arguementValue[0].Name.Value, innerPathStringValue} } else { - fieldName = fieldStatement.Arguments[0].Value.GetValue().(string) + // If the inner path is not a string, this must mean the field is an inline array in which case we only want the base path + path = []string{arguementValue[0].Name.Value} } - slct.AddCount(Count{Name: virtualName, Field: fieldName}) } } - return nil + return path, nil } diff --git a/query/graphql/planner/count.go b/query/graphql/planner/count.go index 133b2575b3..88e5e2b755 100644 --- a/query/graphql/planner/count.go +++ b/query/graphql/planner/count.go @@ -28,11 +28,23 @@ type countNode struct { virtualFieldId string } -func (p *Planner) Count(c *parser.Count, virtualFieldId string) (*countNode, error) { +func (p *Planner) Count(field *parser.Field) (*countNode, error) { + source, err := field.GetAggregateSource() + if err != nil { + return nil, err + } + + var sourceProperty string + if len(source) == 1 { + sourceProperty = source[0] + } else { + sourceProperty = "" + } + return &countNode{ p: p, - sourceProperty: c.Field, - virtualFieldId: virtualFieldId, + sourceProperty: sourceProperty, + virtualFieldId: field.Name, }, nil } @@ -67,3 +79,5 @@ func (n *countNode) Values() map[string]interface{} { func (n *countNode) Next() (bool, error) { return n.plan.Next() } + +func (n *countNode) SetPlan(p planNode) { n.plan = p } diff --git a/query/graphql/planner/planner.go b/query/graphql/planner/planner.go index 05b4383a8e..5cd9d45a25 100644 --- a/query/graphql/planner/planner.go +++ b/query/graphql/planner/planner.go @@ -203,11 +203,7 @@ func (p *Planner) expandSelectTopNodePlan(plan *selectTopNode, parentPlan *selec plan.plan = plan.group } - // consider extracting this out to an `expandAggregatePlan` when adding more aggregates - for _, countPlan := range plan.countPlans { - countPlan.plan = plan.plan - plan.plan = countPlan - } + p.expandAggregatePlans(plan) // if order if plan.sort != nil { @@ -228,6 +224,18 @@ func (p *Planner) expandSelectTopNodePlan(plan *selectTopNode, parentPlan *selec return nil } +type aggregateNode interface { + planNode + SetPlan(plan planNode) +} + +func (p *Planner) expandAggregatePlans(plan *selectTopNode) { + for _, aggregate := range plan.aggregates { + aggregate.SetPlan(plan.plan) + plan.plan = aggregate + } +} + func (p *Planner) expandMultiNode(plan MultiNode, parentPlan *selectTopNode) error { for _, child := range plan.Children() { if err := p.expandPlan(child, parentPlan); err != nil { @@ -297,7 +305,7 @@ func (p *Planner) expandLimitPlan(plan *selectTopNode, parentPlan *selectTopNode // if this is a child node, and the parent select has an aggregate then we need to // replace the hard limit with a render limit to allow the full set of child records // to be aggregated - if parentPlan != nil && len(parentPlan.countPlans) > 0 { + if parentPlan != nil && len(parentPlan.aggregates) > 0 { renderLimit, err := p.RenderLimit(&parser.Limit{ Offset: l.offset, Limit: l.limit, diff --git a/query/graphql/planner/select.go b/query/graphql/planner/select.go index c0645dbabc..76f07604d6 100644 --- a/query/graphql/planner/select.go +++ b/query/graphql/planner/select.go @@ -27,8 +27,8 @@ type selectTopNode struct { group *groupNode sort *sortNode limit planNode - countPlans []*countNode render *renderNode + aggregates []aggregateNode // top of the plan graph plan planNode @@ -130,13 +130,13 @@ func (n *selectNode) Close() error { // creating scanNodes, typeIndexJoinNodes, and splitting // the necessary filters. Its designed to work with the // planner.Select construction call. -func (n *selectNode) initSource(parsed *parser.Select) error { +func (n *selectNode) initSource(parsed *parser.Select) ([]aggregateNode, error) { if parsed.CollectionName == "" { parsed.CollectionName = parsed.Name } sourcePlan, err := n.p.getSource(parsed.CollectionName) if err != nil { - return err + return nil, err } n.source = sourcePlan.plan n.origSource = sourcePlan.plan @@ -170,7 +170,7 @@ func (n *selectNode) initSource(parsed *parser.Select) error { return n.initFields(parsed) } -func (n *selectNode) initFields(parsed *parser.Select) error { +func (n *selectNode) initFields(parsed *parser.Select) ([]aggregateNode, error) { // re-organize the fields slice into reverse-alphabetical // this makes sure the reserved database fields that start with // a "_" end up at the end. So if/when we build our MultiNode @@ -179,65 +179,98 @@ func (n *selectNode) initFields(parsed *parser.Select) error { return !(strings.Compare(parsed.Fields[i].GetName(), parsed.Fields[j].GetName()) < 0) }) + aggregates := []aggregateNode{} // loop over the sub type // at the moment, we're only testing a single sub selection for _, field := range parsed.Fields { - if subtype, ok := field.(*parser.Select); ok { + switch f := field.(type) { + case *parser.Select: // @todo: check select type: // - TypeJoin // - commitScan - if subtype.Name == "_version" { // reserved sub type for object queries + if f.Name == parser.VersionFieldName { // reserved sub type for object queries commitSlct := &parser.CommitSelect{ - Name: subtype.Name, - Alias: subtype.Alias, + Name: f.Name, + Alias: f.Alias, Type: parser.LatestCommits, - Fields: subtype.Fields, + Fields: f.Fields, } commitPlan, err := n.p.CommitSelect(commitSlct) if err != nil { - return err + return nil, err } if err := n.addSubPlan(field.GetName(), commitPlan); err != nil { - return err + return nil, err } - } else if subtype.Root == parser.ObjectSelection { - if subtype.Name == parser.GroupFieldName { - n.groupSelect = subtype + } else if f.Root == parser.ObjectSelection { + if f.Name == parser.GroupFieldName { + n.groupSelect = f } else { - n.addTypeIndexJoin(subtype) + n.addTypeIndexJoin(f) } } + case *parser.Field: + var plan aggregateNode + var aggregateError error + + switch f.Statement.Name.Value { + case parser.CountFieldName: + plan, aggregateError = n.p.Count(f) + case parser.SumFieldName: + plan, aggregateError = n.p.Sum(&n.sourceInfo, f) + default: + continue + } + + if aggregateError != nil { + return nil, aggregateError + } + + aggregates = append(aggregates, plan) + + aggregateError = n.joinAggregatedChild(parsed, f) + if aggregateError != nil { + return nil, aggregateError + } } } - // Handle aggregates of child collection that are not rendered - for _, count := range parsed.Counts { - if count.Field == "" { - continue - } + return aggregates, nil +} - hasChildProperty := false - for _, field := range parsed.Fields { - if count.Field == field.GetName() { - hasChildProperty = true - break - } +// Join any child collections required by the given transformation if the child collections have not been requested for render by the consumer +func (n *selectNode) joinAggregatedChild(parsed *parser.Select, field *parser.Field) error { + source, err := field.GetAggregateSource() + if err != nil { + return err + } + + if len(source) == 0 { + return nil + } + + fieldName := source[0] + hasChildProperty := false + for _, field := range parsed.Fields { + if fieldName == field.GetName() { + hasChildProperty = true + break } + } - // If the child item is not requested, then we have add in the necessary components to force the child records to be scanned through (they wont be rendered) - if !hasChildProperty { - if count.Field == parser.GroupFieldName { - // It doesn't really matter at the moment if multiple counts are requested and we overwrite the n.groupSelect property - n.groupSelect = &parser.Select{ - Name: parser.GroupFieldName, - } - } else if parsed.Root != parser.CommitSelection { - subtype := &parser.Select{ - Name: count.Field, - } - n.addTypeIndexJoin(subtype) + // If the child item is not requested, then we have add in the necessary components to force the child records to be scanned through (they wont be rendered) + if !hasChildProperty { + if fieldName == parser.GroupFieldName { + // It doesn't really matter at the moment if multiple counts are requested and we overwrite the n.groupSelect property + n.groupSelect = &parser.Select{ + Name: parser.GroupFieldName, + } + } else if parsed.Root != parser.CommitSelection { + subtype := &parser.Select{ + Name: fieldName, } + n.addTypeIndexJoin(subtype) } } @@ -307,7 +340,8 @@ func (p *Planner) SelectFromSource(parsed *parser.Select, source planNode, fromC s.sourceInfo = sourceInfo{desc} } - if err := s.initFields(parsed); err != nil { + aggregates, err := s.initFields(parsed) + if err != nil { return nil, err } @@ -326,22 +360,13 @@ func (p *Planner) SelectFromSource(parsed *parser.Select, source planNode, fromC return nil, err } - countPlans := []*countNode{} - for _, countItem := range parsed.Counts { - countNode, err := p.Count(&countItem, countItem.Name) - if err != nil { - return nil, err - } - countPlans = append(countPlans, countNode) - } - top := &selectTopNode{ source: s, render: p.render(parsed), limit: limitPlan, sort: sortPlan, group: groupPlan, - countPlans: countPlans, + aggregates: aggregates, } return top, nil } @@ -355,7 +380,8 @@ func (p *Planner) Select(parsed *parser.Select) (planNode, error) { groupBy := parsed.GroupBy s.renderInfo = &renderInfo{} - if err := s.initSource(parsed); err != nil { + aggregates, err := s.initSource(parsed) + if err != nil { return nil, err } @@ -374,22 +400,13 @@ func (p *Planner) Select(parsed *parser.Select) (planNode, error) { return nil, err } - countPlans := []*countNode{} - for _, countItem := range parsed.Counts { - countNode, err := p.Count(&countItem, countItem.Name) - if err != nil { - return nil, err - } - countPlans = append(countPlans, countNode) - } - top := &selectTopNode{ source: s, render: p.render(parsed), limit: limitPlan, sort: sortPlan, group: groupPlan, - countPlans: countPlans, + aggregates: aggregates, } return top, nil } diff --git a/query/graphql/planner/sum.go b/query/graphql/planner/sum.go new file mode 100644 index 0000000000..8e256decac --- /dev/null +++ b/query/graphql/planner/sum.go @@ -0,0 +1,152 @@ +// Copyright 2020 Source Inc. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +package planner + +import ( + "fmt" + + "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/db/base" + "github.com/sourcenetwork/defradb/query/graphql/parser" +) + +type sumNode struct { + p *Planner + plan planNode + + isFloat bool + sourceCollection string + sourceProperty string + virtualFieldId string +} + +func (p *Planner) Sum(sourceInfo *sourceInfo, field *parser.Field) (*sumNode, error) { + var sourceProperty string + var sourceCollection string + var isFloat bool + + source, err := field.GetAggregateSource() + if err != nil { + return nil, err + } + + if len(source) == 1 { + // If path length is one - we are summing an inline array + sourceCollection = source[0] + sourceProperty = "" + + fieldDescription, fieldDescriptionFound := sourceInfo.collectionDescription.GetField(sourceCollection) + if !fieldDescriptionFound { + return nil, fmt.Errorf("Unable to find field description for field: %s", sourceCollection) + } + + isFloat = fieldDescription.Kind == base.FieldKind_FLOAT_ARRAY + } else if len(source) == 2 { + // If path length is two, we are summing a group or a child relationship + sourceCollection = source[0] + sourceProperty = source[1] + + var childFieldDescription base.FieldDescription + if sourceCollection == parser.GroupFieldName { + // If the source collection is a group, then the description of the collection to sum is this object + fieldDescription, fieldDescriptionFound := sourceInfo.collectionDescription.GetField(sourceProperty) + if !fieldDescriptionFound { + return nil, fmt.Errorf("Unable to find field description for field: %s", sourceProperty) + } + childFieldDescription = fieldDescription + } else { + parentFieldDescription, parentFieldDescriptionFound := sourceInfo.collectionDescription.GetField(sourceCollection) + if !parentFieldDescriptionFound { + return nil, fmt.Errorf("Unable to find parent field description for field: %s", sourceCollection) + } + collectionDescription, err := p.getCollectionDesc(parentFieldDescription.Schema) + if err != nil { + return nil, err + } + fieldDescription, fieldDescriptionFound := collectionDescription.GetField(sourceProperty) + if !fieldDescriptionFound { + return nil, fmt.Errorf("Unable to find child field description for field: %s", sourceProperty) + } + childFieldDescription = fieldDescription + } + + isFloat = childFieldDescription.Kind == base.FieldKind_FLOAT + } else { + sourceCollection = "" + sourceProperty = "" + } + + return &sumNode{ + p: p, + isFloat: isFloat, + sourceCollection: sourceCollection, + sourceProperty: sourceProperty, + virtualFieldId: field.Name, + }, nil +} + +func (n *sumNode) Init() error { + return n.plan.Init() +} + +func (n *sumNode) Start() error { return n.plan.Start() } +func (n *sumNode) Spans(spans core.Spans) { n.plan.Spans(spans) } +func (n *sumNode) Close() error { return n.plan.Close() } +func (n *sumNode) Source() planNode { return n.plan } + +func (n *sumNode) Values() map[string]interface{} { + value := n.plan.Values() + + sum := float64(0) + + if child, hasProperty := value[n.sourceCollection]; hasProperty { + switch childCollection := child.(type) { + case []map[string]interface{}: + for _, childItem := range childCollection { + if childProperty, hasChildProperty := childItem[n.sourceProperty]; hasChildProperty { + switch v := childProperty.(type) { + case int64: + sum += float64(v) + case uint64: + sum += float64(v) + case float64: + sum += v + default: + // do nothing, cannot be summed + } + } + } + case []int64: + for _, childItem := range childCollection { + sum += float64(childItem) + } + case []float64: + for _, childItem := range childCollection { + sum += childItem + } + } + } + + var typedSum interface{} + if n.isFloat { + typedSum = sum + } else { + typedSum = int64(sum) + } + value[n.virtualFieldId] = typedSum + + return value +} + +func (n *sumNode) Next() (bool, error) { + return n.plan.Next() +} + +func (n *sumNode) SetPlan(p planNode) { n.plan = p } diff --git a/query/graphql/schema/generate.go b/query/graphql/schema/generate.go index 82dbc9ec87..c9726791fe 100644 --- a/query/graphql/schema/generate.go +++ b/query/graphql/schema/generate.go @@ -111,7 +111,9 @@ func (g *Generator) fromAST(document *ast.Document) ([]*gql.Object, error) { return nil, err } - g.genAggregateFields() + if err := g.genAggregateFields(); err != nil { + return nil, err + } // resolve types if err := g.manager.ResolveTypes(); err != nil { return nil, err @@ -421,14 +423,22 @@ func getRelationshipName(field *ast.FieldDefinition, hostName gql.ObjectConfig, return genRelationName(hostName.Name, targetName.Name()) } -func (g *Generator) genAggregateFields() { +func (g *Generator) genAggregateFields() error { for _, t := range g.typeDefs { - countField := g.genCountFieldConfig(t) + countField, err := g.genCountFieldConfig(t) + if err != nil { + return err + } t.AddFieldConfig(countField.Name, &countField) + + sumField := g.genSumFieldConfig(t) + t.AddFieldConfig(sumField.Name, &sumField) } + + return nil } -func (g *Generator) genCountFieldConfig(obj *gql.Object) gql.Field { +func (g *Generator) genCountFieldConfig(obj *gql.Object) (gql.Field, error) { inputCfg := gql.EnumConfig{ Name: genTypeName(obj, "CountArg"), Values: gql.EnumValueConfigMap{}, @@ -442,7 +452,10 @@ func (g *Generator) genCountFieldConfig(obj *gql.Object) gql.Field { inputCfg.Values[field.Name] = &gql.EnumValueConfig{Value: field.Name} } countType := gql.NewEnum(inputCfg) - g.manager.schema.AppendType(countType) + err := g.manager.schema.AppendType(countType) + if err != nil { + return gql.Field{}, err + } field := gql.Field{ Name: parser.CountFieldName, @@ -452,9 +465,89 @@ func (g *Generator) genCountFieldConfig(obj *gql.Object) gql.Field { }, } + return field, nil +} + +func (g *Generator) genSumFieldConfig(obj *gql.Object) gql.Field { + var sumType *gql.InputObject + + inputCfg := gql.InputObjectConfig{ + Name: genTypeName(obj, "SumArg"), + } + + inputCfg.Fields = (gql.InputObjectConfigFieldMapThunk)(func() (gql.InputObjectConfigFieldMap, error) { + fields := gql.InputObjectConfigFieldMap{} + + sumBaseArgType, isSumable := g.genSumBaseArgInput(obj) + if isSumable { + err := g.manager.schema.AppendType(sumBaseArgType) + if err != nil { + return gql.InputObjectConfigFieldMap{}, err + } + } + + for _, field := range obj.Fields() { + // we can only sum list items + listType, isList := field.Type.(*gql.List) + if !isList { + continue + } + + if listType.OfType == gql.Float || listType.OfType == gql.Int { + // If it is an inline scalar array then we require an empty object as an argument due to the lack of union input types + fields[field.Name] = &gql.InputObjectFieldConfig{ + Type: &gql.Object{}, + } + } else { + subSumType, isSubTypeSumable := g.manager.schema.TypeMap()[genTypeName(field.Type, "SumBaseArg")] + // If the item is not in the type map, it must contain no summable fields (e.g. no Int/Floats) + if !isSubTypeSumable { + continue + } + fields[field.Name] = &gql.InputObjectFieldConfig{ + Type: subSumType, + } + } + + } + + return fields, nil + }) + sumType = gql.NewInputObject(inputCfg) + g.manager.schema.AppendType(sumType) //this might resolve the thunk? Race issue? + + field := gql.Field{ + Name: parser.SumFieldName, + Type: gql.Float, + Args: gql.FieldConfigArgument{ + "field": newArgConfig(sumType), + }, + } return field } +func (g *Generator) genSumBaseArgInput(obj *gql.Object) (*gql.Enum, bool) { + inputCfg := gql.EnumConfig{ + Name: genTypeName(obj, "SumBaseArg"), + Values: gql.EnumValueConfigMap{}, + } + + hasSumableFields := false + // generate basic filter operator blocks for all the sumable types + for _, field := range obj.Fields() { + if field.Type == gql.Float || field.Type == gql.Int { + hasSumableFields = true + inputCfg.Values[field.Name] = &gql.EnumValueConfig{Value: field.Name} + } + } + + if !hasSumableFields { + return nil, false + } + + return gql.NewEnum(inputCfg), true +} + // Given a parsed ast.Node object, lookup the type in the TypeMap and return if its there // otherwise return an error // ast.Node, can either be a ast.Named type, a ast.List, or a ast.NonNull. diff --git a/query/graphql/schema/generate_test.go b/query/graphql/schema/generate_test.go index b0fc4bf8bd..3278e280ed 100644 --- a/query/graphql/schema/generate_test.go +++ b/query/graphql/schema/generate_test.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "reflect" + "strings" "testing" @@ -64,6 +65,10 @@ func Test_Generator_buildTypesFromAST_SingleScalarField(t *testing.T) { Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.String, @@ -121,6 +126,10 @@ func Test_Generator_buildTypesFromAST_SingleNonNullScalarField(t *testing.T) { Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.NewNonNull(gql.String), @@ -161,6 +170,10 @@ func Test_Generator_buildTypesFromAST_SingleListScalarField(t *testing.T) { Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.NewList(gql.String), @@ -201,6 +214,10 @@ func Test_Generator_buildTypesFromAST_SingleListNonNullScalarField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.NewList(gql.NewNonNull(gql.String)), @@ -241,6 +258,10 @@ func Test_Generator_buildTypesFromAST_SingleNonNullListScalarField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.NewNonNull(gql.NewList(gql.String)), @@ -281,6 +302,10 @@ func Test_Generator_buildTypesFromAST_SingleNonNullListNonNullScalarField(t *tes Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.NewNonNull(gql.NewList(gql.NewNonNull(gql.String))), @@ -326,6 +351,10 @@ func Test_Generator_buildTypesFromAST_MultiScalarField(t *testing.T) { Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.String, @@ -390,6 +419,10 @@ func Test_Generator_buildTypesFromAST_MultiObjectSingleScalarField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.String, @@ -417,6 +450,10 @@ func Test_Generator_buildTypesFromAST_MultiObjectSingleScalarField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "otherField": &gql.Field{ Name: "otherField", Type: gql.Boolean, @@ -463,6 +500,10 @@ func Test_Generator_buildTypesFromAST_MultiObjectMultiScalarField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.String, @@ -494,6 +535,10 @@ func Test_Generator_buildTypesFromAST_MultiObjectMultiScalarField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "otherField": &gql.Field{ Name: "otherField", Type: gql.Boolean, @@ -531,6 +576,10 @@ func Test_Generator_buildTypesFromAST_MultiObjectSingleObjectField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "myField": &gql.Field{ Name: "myField", Type: gql.String, @@ -571,6 +620,10 @@ func Test_Generator_buildTypesFromAST_MultiObjectSingleObjectField(t *testing.T) Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "otherField": &gql.Field{ Name: "otherField", Type: myObj, @@ -629,6 +682,10 @@ func Test_Generator_buildTypesFromAST_MissingObject(t *testing.T) { Name: "_count", Type: gql.Int, }, + "_sum": &gql.Field{ + Name: "_sum", + Type: gql.Float, + }, "otherField": &gql.Field{ Name: "otherField", Type: myObj,