diff --git a/gql/parser.go b/gql/parser.go index 2d59841c3a2..23ac00306e4 100644 --- a/gql/parser.go +++ b/gql/parser.go @@ -1880,6 +1880,40 @@ func parseGroupby(it *lex.ItemIterator, gq *GraphQuery) error { return nil } +func parseType(it *lex.ItemIterator, gq *GraphQuery) error { + it.Next() + if it.Item().Typ != itemLeftRound { + return it.Item().Errorf("Expected a left round after type") + } + + it.Next() + if it.Item().Typ != itemName { + return it.Item().Errorf("Expected a type name inside type directive") + } + typeName := it.Item().Val + + it.Next() + if it.Item().Typ != itemRightRound { + return it.Item().Errorf("Expected ) after the type name in type directive") + } + + // For now @type(TypeName) is equivalent of filtering using the type function. + // Later the type declarations will be used to ensure that the fields inside + // each block correspond to the specified type. + gq.Filter = &FilterTree{ + Func: &Function{ + Name: "type", + Args: []Arg{ + Arg{ + Value: typeName, + }, + }, + }, + } + + return nil +} + // parseFilter parses the filter directive to produce a QueryFilter / parse tree. func parseFilter(it *lex.ItemIterator) (*FilterTree, error) { it.Next() @@ -2110,6 +2144,11 @@ func parseDirective(it *lex.ItemIterator, curp *GraphQuery) error { } curp.IsGroupby = true parseGroupby(it, curp) + case "type": + err := parseType(it, curp) + if err != nil { + return err + } default: return item.Errorf("Unknown directive [%s]", item.Val) } diff --git a/gql/parser_test.go b/gql/parser_test.go index ea221bb49c1..082cbf49b26 100644 --- a/gql/parser_test.go +++ b/gql/parser_test.go @@ -4527,3 +4527,81 @@ func TestTypeInFilter(t *testing.T) { require.Equal(t, 1, len(gq.Query[0].Filter.Func.Args)) require.Equal(t, "Person", gq.Query[0].Filter.Func.Args[0].Value) } + +func TestTypeFilterInPredicate(t *testing.T) { + q := ` + query { + me(func: uid(0x01)) { + friend @filter(type(Person)) { + name + } + } + }` + gq, err := Parse(Request{Str: q}) + require.NoError(t, err) + require.Equal(t, 1, len(gq.Query)) + require.Equal(t, "uid", gq.Query[0].Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children)) + require.Equal(t, "friend", gq.Query[0].Children[0].Attr) + + require.Equal(t, "type", gq.Query[0].Children[0].Filter.Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children[0].Filter.Func.Args)) + require.Equal(t, "Person", gq.Query[0].Children[0].Filter.Func.Args[0].Value) + + require.Equal(t, 1, len(gq.Query[0].Children[0].Children)) + require.Equal(t, "name", gq.Query[0].Children[0].Children[0].Attr) +} + +func TestTypeInPredicate(t *testing.T) { + q := ` + query { + me(func: uid(0x01)) { + friend @type(Person) { + name + } + } + }` + gq, err := Parse(Request{Str: q}) + require.NoError(t, err) + require.Equal(t, 1, len(gq.Query)) + require.Equal(t, "uid", gq.Query[0].Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children)) + require.Equal(t, "friend", gq.Query[0].Children[0].Attr) + + require.Equal(t, "type", gq.Query[0].Children[0].Filter.Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children[0].Filter.Func.Args)) + require.Equal(t, "Person", gq.Query[0].Children[0].Filter.Func.Args[0].Value) + + require.Equal(t, 1, len(gq.Query[0].Children[0].Children)) + require.Equal(t, "name", gq.Query[0].Children[0].Children[0].Attr) +} + +func TestMultipleTypeDirectives(t *testing.T) { + q := ` + query { + me(func: uid(0x01)) { + friend @type(Person) { + pet @type(Animal) { + name + } + } + } + }` + gq, err := Parse(Request{Str: q}) + require.NoError(t, err) + require.Equal(t, 1, len(gq.Query)) + require.Equal(t, "uid", gq.Query[0].Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children)) + require.Equal(t, "friend", gq.Query[0].Children[0].Attr) + + require.Equal(t, "type", gq.Query[0].Children[0].Filter.Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children[0].Filter.Func.Args)) + require.Equal(t, "Person", gq.Query[0].Children[0].Filter.Func.Args[0].Value) + + require.Equal(t, 1, len(gq.Query[0].Children[0].Children)) + require.Equal(t, "pet", gq.Query[0].Children[0].Children[0].Attr) + + require.Equal(t, "type", gq.Query[0].Children[0].Children[0].Filter.Func.Name) + require.Equal(t, 1, len(gq.Query[0].Children[0].Children[0].Filter.Func.Args)) + require.Equal(t, "Animal", gq.Query[0].Children[0].Children[0].Filter.Func.Args[0].Value) +} diff --git a/query/common_test.go b/query/common_test.go index e128fd186e2..ae3299f0943 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -202,6 +202,7 @@ symbol : string @index(exact) . room : string @index(term) . office.room : [uid] . best_friend : uid . +pet : [uid] . ` func populateCluster(t *testing.T) { @@ -211,6 +212,12 @@ func populateCluster(t *testing.T) { addTriplesToCluster(t, ` <1> "Michonne" . + <2> "King Lear" . + <3> "Margaret" . + <4> "Leonard" . + <5> "Garfield" . + <6> "Bear" . + <7> "Nemo" . <23> "Rick Grimes" . <24> "Glenn Rhee" . <25> "Daryl Dixon" . @@ -425,6 +432,15 @@ func populateCluster(t *testing.T) { <2> "Person" . <3> "Person" . <4> "Person" . + <5> "Animal" . + <6> "Animal" . + + <2> <5> . + <3> <6> . + <4> <7> . + + <2> <3> . + <2> <4> . `) addGeoPointToCluster(t, 1, "loc", []float64{1.1, 2.0}) diff --git a/query/query3_test.go b/query/query3_test.go index 5de5ec985c7..810a120acbb 100644 --- a/query/query3_test.go +++ b/query/query3_test.go @@ -1913,3 +1913,34 @@ func TestTypeFilterUnknownType(t *testing.T) { js := processQueryNoErr(t, query) require.JSONEq(t, `{"data": {"me":[]}}`, js) } + +func TestTypeDirectiveInPredicate(t *testing.T) { + query := ` + { + me(func: uid(0x2)) { + enemy @type(Person) { + name + } + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[{"enemy":[{"name":"Margaret"}, {"name":"Leonard"}]}]}}`, js) +} + +func TestMultipleTypeDirectivesInPredicate(t *testing.T) { + query := ` + { + me(func: uid(0x2)) { + enemy @type(Person) { + name + pet @type(Animal) { + name + } + } + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[{"enemy":[{"name":"Margaret", "pet":[{"name":"Bear"}]}, {"name":"Leonard"}]}]}}`, js) +}