diff --git a/gql/parser.go b/gql/parser.go index 2dbf9dcf347..408742d897d 100644 --- a/gql/parser.go +++ b/gql/parser.go @@ -2185,6 +2185,38 @@ func parseVarList(it *lex.ItemIterator, gq *GraphQuery) (int, error) { return count, nil } +func parseTypeList(it *lex.ItemIterator, gq *GraphQuery) error { + typeList := it.Item().Val + expectArg := false +loop: + for it.Next() { + item := it.Item() + switch item.Typ { + case itemRightRound: + it.Prev() + break loop + case itemComma: + if expectArg { + return item.Errorf("Expected a variable but got comma") + } + expectArg = true + case itemName: + if !expectArg { + return item.Errorf("Expected a variable but got comma") + } + typeList = fmt.Sprintf("%s,%s", typeList, item.Val) + expectArg = false + default: + return item.Errorf("Unexpected token %s when reading a type list", item.Val) + } + } + if expectArg { + return it.Item().Errorf("Unnecessary comma in val()") + } + gq.Expand = typeList + return nil +} + func parseDirective(it *lex.ItemIterator, curp *GraphQuery) error { valid := true it.Prev() @@ -2808,7 +2840,9 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { case "_reverse_": return item.Errorf("Argument _reverse_ has been deprecated") default: - return item.Errorf("Invalid argument %v in expand()", item.Val) + if err := parseTypeList(it, child); err != nil { + return err + } } it.Next() // Consume ')' gq.Children = append(gq.Children, child) diff --git a/gql/parser_test.go b/gql/parser_test.go index 03766178d0b..5903dd1dc1f 100644 --- a/gql/parser_test.go +++ b/gql/parser_test.go @@ -190,6 +190,34 @@ func TestParseQueryExpandReverse(t *testing.T) { require.Contains(t, err.Error(), "Argument _reverse_ has been deprecated") } +func TestParseQueryExpandType(t *testing.T) { + query := ` + { + var(func: uid( 0x0a)) { + friends { + expand(Person) + } + } + } +` + _, err := Parse(Request{Str: query}) + require.NoError(t, err) +} + +func TestParseQueryExpandMultipleTypes(t *testing.T) { + query := ` + { + var(func: uid( 0x0a)) { + friends { + expand(Person, Relative) + } + } + } +` + _, err := Parse(Request{Str: query}) + require.NoError(t, err) +} + func TestParseQueryAliasListPred(t *testing.T) { query := ` { @@ -4823,3 +4851,23 @@ func TestTypeFilterInPredicate(t *testing.T) { require.Equal(t, 1, len(gq.Query[0].Children[0].Children)) require.Equal(t, "name", gq.Query[0].Children[0].Children[0].Attr) } + +func TestParseExpandType(t *testing.T) { + query := ` + { + var(func: has(name)) { + expand(Person,Animal) { + uid + } + } + } +` + gq, err := Parse(Request{Str: query}) + require.NoError(t, err) + require.Equal(t, 1, len(gq.Query)) + require.Equal(t, 1, len(gq.Query[0].Children)) + require.Equal(t, "expand", gq.Query[0].Children[0].Attr) + require.Equal(t, "Person,Animal", gq.Query[0].Children[0].Expand) + require.Equal(t, 1, len(gq.Query[0].Children[0].Children)) + require.Equal(t, "uid", gq.Query[0].Children[0].Children[0].Attr) +} diff --git a/query/common_test.go b/query/common_test.go index 1d43a79f59d..ed1c029b28c 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -199,6 +199,11 @@ type CarModel { <~previous_model> } +type Object { + name + owner +} + type SchoolInfo { name abbr @@ -266,6 +271,7 @@ newname : string @index(exact, term) . newage : int . boss : uid . newfriend : [uid] . +owner : [uid] . ` func populateCluster() { @@ -545,11 +551,14 @@ func populateCluster() { <201> "CarModel" . <201> <200> . + <202> "Car" . <202> "Toyota" . <202> "2009" . <202> "Prius" . <202> "プリウス"@jp . + <202> <203> . <202> "CarModel" . + <202> "Object" . # data for regexp testing _:luke "Luke" . diff --git a/query/query.go b/query/query.go index 03151391865..6d1090839cf 100644 --- a/query/query.go +++ b/query/query.go @@ -1829,9 +1829,14 @@ func expandSubgraph(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) { preds = getPredicatesFromTypes(types) default: - span.Annotate(nil, "expand default") - // We already have the predicates populated from the var. - preds = getPredsFromVals(child.ExpandPreds) + if len(child.ExpandPreds) > 0 { + span.Annotate(nil, "expand default") + // We already have the predicates populated from the var. + preds = getPredsFromVals(child.ExpandPreds) + } else { + types := strings.Split(child.Params.Expand, ",") + preds = getPredicatesFromTypes(types) + } } preds = uniquePreds(preds) @@ -1841,7 +1846,10 @@ func expandSubgraph(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) { Attr: pred, } temp.Params = child.Params - temp.Params.ExpandAll = child.Params.Expand == "_all_" + // TODO(martinmr): simplify this condition once _reverse_ and _forward_ + // are removed + temp.Params.ExpandAll = child.Params.Expand != "_reverse_" && + child.Params.Expand != "_forward_" temp.Params.ParentVars = make(map[string]varValue) for k, v := range child.Params.ParentVars { temp.Params.ParentVars[k] = v diff --git a/query/query0_test.go b/query/query0_test.go index da9e82e24f6..7360cba3357 100644 --- a/query/query0_test.go +++ b/query/query0_test.go @@ -153,9 +153,9 @@ func TestQueryCountEmptyNames(t *testing.T) { {in: `{q(func: has(name)) @filter(eq(name, "")) {count(uid)}}`, out: `{"data":{"q": [{"count":2}]}}`}, {in: `{q(func: has(name)) @filter(gt(name, "")) {count(uid)}}`, - out: `{"data":{"q": [{"count":46}]}}`}, + out: `{"data":{"q": [{"count":47}]}}`}, {in: `{q(func: has(name)) @filter(ge(name, "")) {count(uid)}}`, - out: `{"data":{"q": [{"count":48}]}}`}, + out: `{"data":{"q": [{"count":49}]}}`}, {in: `{q(func: has(name)) @filter(lt(name, "")) {count(uid)}}`, out: `{"data":{"q": [{"count":0}]}}`}, {in: `{q(func: has(name)) @filter(le(name, "")) {count(uid)}}`, @@ -166,7 +166,7 @@ func TestQueryCountEmptyNames(t *testing.T) { out: `{"data":{"q": [{"count":2}]}}`}, // NOTE: match with empty string filters values greater than the max distance. {in: `{q(func: has(name)) @filter(match(name, "", 8)) {count(uid)}}`, - out: `{"data":{"q": [{"count":28}]}}`}, + out: `{"data":{"q": [{"count":29}]}}`}, {in: `{q(func: has(name)) @filter(uid_in(name, "")) {count(uid)}}`, failure: `Value "" in uid_in is not a number`}, } diff --git a/query/query4_test.go b/query/query4_test.go index 1f53ef370e2..a24006cd2c2 100644 --- a/query/query4_test.go +++ b/query/query4_test.go @@ -306,7 +306,34 @@ func TestTypeExpandLang(t *testing.T) { }` js := processQueryNoErr(t, query) require.JSONEq(t, `{"data": {"q":[ - {"make":"Toyota","model":"Prius", "model@jp":"プリウス", "year":2009}]}}`, js) + {"name": "Car", "make":"Toyota","model":"Prius", "model@jp":"プリウス", "year":2009, + "owner": [{"uid": "0xcb"}]}]}}`, js) +} + +func TestTypeExpandExplicitType(t *testing.T) { + query := `{ + q(func: eq(make, "Toyota")) { + expand(Object) { + uid + } + } + }` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"q":[{"name":"Car", "owner": [{"uid": "0xcb"}]}]}}`, js) +} + +func TestTypeExpandMultipleExplicitTypes(t *testing.T) { + query := `{ + q(func: eq(make, "Toyota")) { + expand(CarModel, Object) { + uid + } + } + }` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"q":[ + {"name": "Car", "make":"Toyota","model":"Prius", "model@jp":"プリウス", "year":2009, + "owner": [{"uid": "0xcb"}]}]}}`, js) } // Test Related to worker based pagination. diff --git a/query/query_facets_test.go b/query/query_facets_test.go index f0d13aabd8b..df1af008032 100644 --- a/query/query_facets_test.go +++ b/query/query_facets_test.go @@ -998,6 +998,6 @@ func TestTypeExpandFacets(t *testing.T) { }` js := processQueryNoErr(t, query) require.JSONEq(t, `{"data": {"q":[ - {"make":"Toyota","model":"Prius", "model@jp":"プリウス", "model|type":"Electric", - "year":2009}]}}`, js) + {"name": "Car", "make":"Toyota","model":"Prius", "model@jp":"プリウス", + "model|type":"Electric", "year":2009, "owner": [{"uid": "0xcb"}]}]}}`, js) }