Skip to content

Commit

Permalink
Merge branch 'main' into DX-50
Browse files Browse the repository at this point in the history
  • Loading branch information
Talent Zeng committed Sep 23, 2024
2 parents 78797b8 + ce39c99 commit 8064e06
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 85 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
* @openfga/dx @openfga/contractors-ides
README.md @openfga/product @openfga/community @openfga/dx
pkg/go/graph/* @openfga/backend
42 changes: 39 additions & 3 deletions pkg/go/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,48 @@ import (
"gonum.org/v1/gonum/graph/topo"
)

var ErrBuildingGraph = errors.New("cannot build graph")
var (
ErrBuildingGraph = errors.New("cannot build graph")
ErrQueryingGraph = errors.New("cannot query graph")
)

type DrawingDirection bool

const (
// DrawingDirectionListObjects is when terminal types have outgoing edges and no incoming edges.
DrawingDirectionListObjects DrawingDirection = true
DrawingDirectionCheck DrawingDirection = false
// DrawingDirectionCheck is when terminal types have incoming edges and no outgoing edges.
DrawingDirectionCheck DrawingDirection = false
)

type AuthorizationModelGraph struct {
*multi.DirectedGraph
drawingDirection DrawingDirection
ids NodeLabelsToIDs
}

func (g *AuthorizationModelGraph) GetDrawingDirection() DrawingDirection {
return g.drawingDirection
}

// GetNodeByLabel provides O(1) access to a node.
func (g *AuthorizationModelGraph) GetNodeByLabel(label string) (*AuthorizationModelNode, error) {
id, ok := g.ids[label]
if !ok {
return nil, fmt.Errorf("%w: node with label %s not found", ErrQueryingGraph, label)
}

node := g.Node(id)
if node == nil {
return nil, fmt.Errorf("%w: node with id %d not found", ErrQueryingGraph, id)
}

casted, ok := node.(*AuthorizationModelNode)
if !ok {
return nil, fmt.Errorf("%w: could not cast to AuthorizationModelNode", ErrQueryingGraph)
}

return casted, nil
}

// Reversed returns a full copy of the graph, but with the direction of the arrows flipped.
Expand Down Expand Up @@ -57,9 +87,15 @@ func (g *AuthorizationModelGraph) Reversed() (*AuthorizationModelGraph, error) {
}
}

// Make a brand new copy of the map.
copyIDs := make(NodeLabelsToIDs, len(g.ids))
for k, v := range g.ids {
copyIDs[k] = v
}

multigraph, ok := graphBuilder.DirectedMultigraphBuilder.(*multi.DirectedGraph)
if ok {
return &AuthorizationModelGraph{multigraph, !g.drawingDirection}, nil
return &AuthorizationModelGraph{multigraph, !g.drawingDirection, copyIDs}, nil
}

return nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph)
Expand Down
28 changes: 13 additions & 15 deletions pkg/go/graph/graph_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,28 @@ import (
"gonum.org/v1/gonum/graph/multi"
)

type NodeLabelsToIDs map[string]int64

type AuthorizationModelGraphBuilder struct {
graph.DirectedMultigraphBuilder

ids map[string]int64 // nodes: unique labels to ids. Used to find nodes by label.
ids NodeLabelsToIDs // nodes: unique labels to ids. Used to find nodes by label.
}

// NewAuthorizationModelGraph builds an authorization model in graph form.
// For example, types such as `group`, usersets such as `group#member` and wildcards `group:*` are encoded as nodes.
//
// The edges are defined by the assignments, e.g.
// `define viewer: [group]` defines an edge from group to document#viewer.
// Conditions are not encoded in the graph,
// and the two edges in an exclusion are not distinguished.
// By default, the graph is drawn from bottom to top (i.e. terminal types have outgoing edges and no incoming edges).
// Conditions are not encoded in the graph.
func NewAuthorizationModelGraph(model *openfgav1.AuthorizationModel) (*AuthorizationModelGraph, error) {
res, err := parseModel(model)
res, ids, err := parseModel(model)
if err != nil {
return nil, err
}

return &AuthorizationModelGraph{res, DrawingDirectionListObjects}, nil
return &AuthorizationModelGraph{res, DrawingDirectionListObjects, ids}, nil
}

func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, error) {
func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, NodeLabelsToIDs, error) {
graphBuilder := &AuthorizationModelGraphBuilder{
multi.NewDirectedGraph(), map[string]int64{},
}
Expand Down Expand Up @@ -67,10 +66,10 @@ func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, erro

multigraph, ok := graphBuilder.DirectedMultigraphBuilder.(*multi.DirectedGraph)
if ok {
return multigraph, nil
return multigraph, graphBuilder.ids, nil
}

return nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph)
return nil, nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph)
}

func checkRewrite(graphBuilder *AuthorizationModelGraphBuilder, parentNode *AuthorizationModelNode, model *openfgav1.AuthorizationModel, rewrite *openfgav1.Userset, typeDef *openfgav1.TypeDefinition, relation string) {
Expand Down Expand Up @@ -248,10 +247,7 @@ func (g *AuthorizationModelGraphBuilder) HasEdge(from, to graph.Node, edgeType E
}

iter := g.Lines(from.ID(), to.ID())
for {
if !iter.Next() {
return false
}
for iter.Next() {
l := iter.Line()
edge, ok := l.(*AuthorizationModelEdge)
if !ok {
Expand All @@ -261,6 +257,8 @@ func (g *AuthorizationModelGraphBuilder) HasEdge(from, to graph.Node, edgeType E
return true
}
}

return false
}

func typeAndRelationExists(model *openfgav1.AuthorizationModel, typeName, relation string) bool {
Expand Down
53 changes: 30 additions & 23 deletions pkg/go/graph/graph_edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,41 @@ type AuthorizationModelEdge struct {

var _ encoding.Attributer = (*AuthorizationModelEdge)(nil)

func (n *AuthorizationModelEdge) Attributes() []encoding.Attribute {
var attrs []encoding.Attribute

if n.edgeType == DirectEdge {
attrs = append(attrs, encoding.Attribute{
Key: "label",
Value: "direct",
})
}

if n.edgeType == ComputedEdge {
attrs = append(attrs, encoding.Attribute{
Key: "style",
Value: "dashed",
})
}
func (n *AuthorizationModelEdge) EdgeType() EdgeType {
return n.edgeType
}

if n.edgeType == TTUEdge {
func (n *AuthorizationModelEdge) Attributes() []encoding.Attribute {
switch n.edgeType {
case DirectEdge:
return []encoding.Attribute{
{
Key: "label",
Value: "direct",
},
}
case ComputedEdge:
return []encoding.Attribute{
{
Key: "style",
Value: "dashed",
},
}
case TTUEdge:
headLabelAttrValue := n.conditionedOn
if headLabelAttrValue == "" {
headLabelAttrValue = "missing"
}

attrs = append(attrs, encoding.Attribute{
Key: "headlabel",
Value: headLabelAttrValue,
})
return []encoding.Attribute{
{
Key: "headlabel",
Value: headLabelAttrValue,
},
}
case RewriteEdge:
return []encoding.Attribute{}
default:
return []encoding.Attribute{}
}

return attrs
}
78 changes: 76 additions & 2 deletions pkg/go/graph/graph_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graph

import (
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -71,10 +72,8 @@ rankdir=TB
model := language.MustTransformDSLToProto(testCase.model)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)
require.Equal(t, DrawingDirectionListObjects, graph.drawingDirection)
reversedGraph, err := graph.Reversed()
require.NoError(t, err)
require.Equal(t, DrawingDirectionCheck, reversedGraph.drawingDirection)
actualDOT := reversedGraph.GetDOT()
actualSorted := getSorted(actualDOT)
expectedSorted := getSorted(testCase.expectedOutput)
Expand All @@ -85,3 +84,78 @@ rankdir=TB
})
}
}

func TestGetDrawingDirection(t *testing.T) {
t.Parallel()
model := language.MustTransformDSLToProto(`
model
schema 1.1
type user
type company
relations
define member: [user]`)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)
require.Equal(t, DrawingDirectionListObjects, graph.GetDrawingDirection())
reversedGraph, err := graph.Reversed()
require.NoError(t, err)
require.Equal(t, DrawingDirectionCheck, reversedGraph.GetDrawingDirection())
}

func TestGetNodeByLabel(t *testing.T) {
t.Parallel()
model := language.MustTransformDSLToProto(`
model
schema 1.1
type user
type company
relations
define member: [user with cond, user:* with cond]
define owner: [user]
define approved_member: member or owner
type group
relations
define approved_member: [user]
type license
relations
define active_member: approved_member from owner
define owner: [company, group]`)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)

testCases := []struct {
label string
expectedFound bool
}{
// found
{"user", true},
{"user:*", true},
{"company", true},
{"company#member", true},
{"company#owner", true},
{"company#approved_member", true},
{"group", true},
{"group#approved_member", true},
{"license", true},
{"license#active_member", true},
{"license#owner", true},
// not found
{"unknown", false},
{"unknown#unknown", false},
{"user with cond", false},
{"user:* with cond", false},
}
for i, testCase := range testCases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Parallel()
node, err := graph.GetNodeByLabel(testCase.label)
if testCase.expectedFound {
require.NoError(t, err)
require.NotNil(t, node)
} else {
require.ErrorIs(t, err, ErrQueryingGraph)
require.Nil(t, node)
}
})
}
}
Loading

0 comments on commit 8064e06

Please sign in to comment.