diff --git a/pkg/go/graph/graph_builder.go b/pkg/go/graph/graph_builder.go index b30b8e97..e1c7abb1 100644 --- a/pkg/go/graph/graph_builder.go +++ b/pkg/go/graph/graph_builder.go @@ -135,7 +135,7 @@ func parseThis(graphBuilder *AuthorizationModelGraphBuilder, parentNode graph.No if directlyRelatedDef.GetWildcard() != nil { // direct assignment to wildcard assignableWildcard := directlyRelatedDef.GetType() + ":*" - curNode = graphBuilder.GetOrAddNode(assignableWildcard, assignableWildcard, SpecificType) + curNode = graphBuilder.GetOrAddNode(assignableWildcard, assignableWildcard, SpecificTypeWildcard) } if directlyRelatedDef.GetRelation() != "" { diff --git a/pkg/go/graph/graph_node.go b/pkg/go/graph/graph_node.go index 3204effe..0560a089 100644 --- a/pkg/go/graph/graph_node.go +++ b/pkg/go/graph/graph_node.go @@ -11,6 +11,7 @@ const ( SpecificType NodeType = 0 // e.g. `group` SpecificTypeAndRelation NodeType = 1 // e.g. `group#viewer` OperatorNode NodeType = 2 // e.g. union + SpecificTypeWildcard NodeType = 3 // e.g. `group:*` ) type AuthorizationModelNode struct { diff --git a/pkg/go/graph/graph_test.go b/pkg/go/graph/graph_test.go index 7cdf35d5..e313b454 100644 --- a/pkg/go/graph/graph_test.go +++ b/pkg/go/graph/graph_test.go @@ -159,3 +159,75 @@ func TestGetNodeByLabel(t *testing.T) { }) } } + +func TestGetNodeTypes(t *testing.T) { + t.Parallel() + model := language.MustTransformDSLToProto(` + model + schema 1.1 + type user + type group + relations + define member: [user] + type company + relations + define wildcard: [user:*] + define direct: [user] + define userset: [group#member] + define intersectionRelation: wildcard and direct + define unionRelation: wildcard or direct + define differenceRelation: wildcard but not direct`) + graph, err := NewAuthorizationModelGraph(model) + require.NoError(t, err) + + testCases := []struct { + label string + expectedNodeType NodeType + }{ + {"user", SpecificType}, + {"user:*", SpecificTypeWildcard}, + {"group", SpecificType}, + {"group#member", SpecificTypeAndRelation}, + {"company", SpecificType}, + {"company#wildcard", SpecificTypeAndRelation}, + {"company#direct", SpecificTypeAndRelation}, + {"company#userset", SpecificTypeAndRelation}, + {"company#intersectionRelation", SpecificTypeAndRelation}, + {"company#unionRelation", SpecificTypeAndRelation}, + {"company#differenceRelation", SpecificTypeAndRelation}, + } + for _, testCase := range testCases { + t.Run(testCase.label, func(t *testing.T) { + t.Parallel() + node, err := graph.GetNodeByLabel(testCase.label) + require.NoError(t, err) + require.NotNil(t, node) + require.Equal(t, testCase.expectedNodeType, node.NodeType(), "expected node type %d but got %d", testCase.expectedNodeType, node.NodeType()) + }) + } + + // testing the operator nodes is not so straightforward... + var unionNodes, differenceNodes, intersectionNodes []*AuthorizationModelNode + + iterNodes := graph.Nodes() + for iterNodes.Next() { + node, ok := iterNodes.Node().(*AuthorizationModelNode) + require.True(t, ok) + if node.nodeType != OperatorNode { + continue + } + + switch node.label { + case "union": + unionNodes = append(unionNodes, node) + case "intersection": + intersectionNodes = append(intersectionNodes, node) + case "exclusion": + differenceNodes = append(differenceNodes, node) + } + } + + require.Len(t, unionNodes, 1) + require.Len(t, differenceNodes, 1) + require.Len(t, intersectionNodes, 1) +}