Skip to content

Commit

Permalink
graph: Pass parsed odata request to the identity backend
Browse files Browse the repository at this point in the history
In preparation for some more advanced queries pass the parse odata request
tVo the identity backend methods instead of the raw url.Values{}. This also
add some helpers for validating $expand and $search queries to reject
some unsupported queries.

Also remove support for `$select=memberOf` and `$select=drive|drives` queries
and stick to the technically correct `$expand=...`.
  • Loading branch information
rhafer committed Feb 8, 2023
1 parent 25d2a2b commit 26f7523
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 74 deletions.
5 changes: 3 additions & 2 deletions services/graph/pkg/identity/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/url"

"github.com/CiscoM31/godata"
cs3 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
libregraph "github.com/owncloud/libre-graph-api-go"
"github.com/owncloud/ocis/v2/services/graph/pkg/service/v0/errorcode"
Expand All @@ -25,8 +26,8 @@ type Backend interface {
DeleteUser(ctx context.Context, nameOrID string) error
// UpdateUser applies changes to given user, identified by username or id
UpdateUser(ctx context.Context, nameOrID string, user libregraph.User) (*libregraph.User, error)
GetUser(ctx context.Context, nameOrID string, queryParam url.Values) (*libregraph.User, error)
GetUsers(ctx context.Context, queryParam url.Values) ([]*libregraph.User, error)
GetUser(ctx context.Context, nameOrID string, oreq *godata.GoDataRequest) (*libregraph.User, error)
GetUsers(ctx context.Context, oreq *godata.GoDataRequest) ([]*libregraph.User, error)

// CreateGroup creates the supplied group in the identity backend.
CreateGroup(ctx context.Context, group libregraph.Group) (*libregraph.Group, error)
Expand Down
15 changes: 10 additions & 5 deletions services/graph/pkg/identity/cs3.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/url"

"github.com/CiscoM31/godata"
cs3group "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1"
cs3user "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
cs3rpc "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
Expand Down Expand Up @@ -39,7 +40,8 @@ func (i *CS3) UpdateUser(ctx context.Context, nameOrID string, user libregraph.U
return nil, errNotImplemented
}

func (i *CS3) GetUser(ctx context.Context, userID string, queryParam url.Values) (*libregraph.User, error) {
// GetUser implements the Backend Interface.
func (i *CS3) GetUser(ctx context.Context, userID string, _ *godata.GoDataRequest) (*libregraph.User, error) {
logger := i.Logger.SubloggerWithRequestID(ctx)
logger.Debug().Str("backend", "cs3").Msg("GetUser")
client, err := pool.GetGatewayServiceClient(i.Config.Address, i.Config.GetRevaOptions()...)
Expand Down Expand Up @@ -67,7 +69,8 @@ func (i *CS3) GetUser(ctx context.Context, userID string, queryParam url.Values)
return CreateUserModelFromCS3(res.User), nil
}

func (i *CS3) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregraph.User, error) {
// GetUsers implements the Backend Interface.
func (i *CS3) GetUsers(ctx context.Context, oreq *godata.GoDataRequest) ([]*libregraph.User, error) {
logger := i.Logger.SubloggerWithRequestID(ctx)
logger.Debug().Str("backend", "cs3").Msg("GetUsers")
client, err := pool.GetGatewayServiceClient(i.Config.Address, i.Config.GetRevaOptions()...)
Expand All @@ -76,9 +79,9 @@ func (i *CS3) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregrap
return nil, errorcode.New(errorcode.ServiceNotAvailable, err.Error())
}

search := queryParam.Get("search")
if search == "" {
search = queryParam.Get("$search")
search, err := GetSearchValues(oreq.Query)
if err != nil {
return nil, err
}

res, err := client.FindUsers(ctx, &cs3user.FindUsersRequest{
Expand Down Expand Up @@ -107,6 +110,7 @@ func (i *CS3) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregrap
return users, nil
}

// GetGroups implements the Backend Interface.
func (i *CS3) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregraph.Group, error) {
logger := i.Logger.SubloggerWithRequestID(ctx)
logger.Debug().Str("backend", "cs3").Msg("GetGroups")
Expand Down Expand Up @@ -153,6 +157,7 @@ func (i *CS3) CreateGroup(ctx context.Context, group libregraph.Group) (*libregr
return nil, errorcode.New(errorcode.NotSupported, "not implemented")
}

// GetGroup implements the Backend Interface.
func (i *CS3) GetGroup(ctx context.Context, groupID string, queryParam url.Values) (*libregraph.Group, error) {
logger := i.Logger.SubloggerWithRequestID(ctx)
logger.Debug().Str("backend", "cs3").Msg("GetGroup")
Expand Down
37 changes: 24 additions & 13 deletions services/graph/pkg/identity/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import (
"context"
"errors"
"fmt"
"net/url"
"strings"

"github.com/CiscoM31/godata"
"github.com/go-ldap/ldap/v3"
"github.com/gofrs/uuid"
libregraph "github.com/owncloud/libre-graph-api-go"
Expand Down Expand Up @@ -369,20 +368,27 @@ func (i *LDAP) getLDAPUserByFilter(filter string) (*ldap.Entry, error) {
return i.searchLDAPEntryByFilter(i.userBaseDN, attrs, filter)
}

func (i *LDAP) GetUser(ctx context.Context, nameOrID string, queryParam url.Values) (*libregraph.User, error) {
// GetUser implements the Backend Interface.
func (i *LDAP) GetUser(ctx context.Context, nameOrID string, oreq *godata.GoDataRequest) (*libregraph.User, error) {
logger := i.logger.SubloggerWithRequestID(ctx)
logger.Debug().Str("backend", "ldap").Msg("GetUser")

e, err := i.getLDAPUserByNameOrID(nameOrID)
if err != nil {
return nil, err
}

u := i.createUserModelFromLDAP(e)
if u == nil {
return nil, ErrNotFound
}
sel := strings.Split(queryParam.Get("$select"), ",")
exp := strings.Split(queryParam.Get("$expand"), ",")
if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") {

exp, err := GetExpandValues(oreq.Query)
if err != nil {
return nil, err
}

if slices.Contains(exp, "memberOf") {
userGroups, err := i.getGroupsForUser(e.DN)
if err != nil {
return nil, err
Expand All @@ -392,14 +398,21 @@ func (i *LDAP) GetUser(ctx context.Context, nameOrID string, queryParam url.Valu
return u, nil
}

func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregraph.User, error) {
// GetUsers implements the Backend Interface.
func (i *LDAP) GetUsers(ctx context.Context, oreq *godata.GoDataRequest) ([]*libregraph.User, error) {
logger := i.logger.SubloggerWithRequestID(ctx)
logger.Debug().Str("backend", "ldap").Msg("GetUsers")

search := queryParam.Get("search")
if search == "" {
search = queryParam.Get("$search")
search, err := GetSearchValues(oreq.Query)
if err != nil {
return nil, err
}

exp, err := GetExpandValues(oreq.Query)
if err != nil {
return nil, err
}

var userFilter string
if search != "" {
search = ldap.EscapeFilter(search)
Expand Down Expand Up @@ -439,14 +452,12 @@ func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregra
users := make([]*libregraph.User, 0, len(res.Entries))

for _, e := range res.Entries {
sel := strings.Split(queryParam.Get("$select"), ",")
exp := strings.Split(queryParam.Get("$expand"), ",")
u := i.createUserModelFromLDAP(e)
// Skip invalid LDAP users
if u == nil {
continue
}
if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") {
if slices.Contains(exp, "memberOf") {
userGroups, err := i.getGroupsForUser(e.DN)
if err != nil {
return nil, err
Expand Down
52 changes: 24 additions & 28 deletions services/graph/pkg/identity/ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/url"
"testing"

"github.com/CiscoM31/godata"
"github.com/go-ldap/ldap/v3"
libregraph "github.com/owncloud/libre-graph-api-go"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
Expand Down Expand Up @@ -172,23 +173,24 @@ func TestGetUser(t *testing.T) {
nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock")))
b, _ := getMockedBackend(lm, lconfig, &logger)

queryParamExpand := url.Values{
"$expand": []string{"memberOf"},
}
queryParamSelect := url.Values{
"$select": []string{"memberOf"},
odataReqDefault, err := godata.ParseRequest(context.Background(), "",
url.Values{})
if err != nil {
t.Errorf("Expected success got '%s'", err.Error())
}
_, err := b.GetUser(context.Background(), "fred", nil)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())

odataReqExpand, err := godata.ParseRequest(context.Background(), "",
url.Values{"$expand": []string{"memberOf"}})
if err != nil {
t.Errorf("Expected success got '%s'", err.Error())
}

_, err = b.GetUser(context.Background(), "fred", queryParamExpand)
_, err = b.GetUser(context.Background(), "fred", odataReqDefault)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}

_, err = b.GetUser(context.Background(), "fred", queryParamSelect)
_, err = b.GetUser(context.Background(), "fred", odataReqExpand)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}
Expand All @@ -199,17 +201,12 @@ func TestGetUser(t *testing.T) {
Return(
&ldap.SearchResult{}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
_, err = b.GetUser(context.Background(), "fred", nil)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}

_, err = b.GetUser(context.Background(), "fred", queryParamExpand)
_, err = b.GetUser(context.Background(), "fred", odataReqDefault)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}

_, err = b.GetUser(context.Background(), "fred", queryParamSelect)
_, err = b.GetUser(context.Background(), "fred", odataReqExpand)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}
Expand All @@ -224,21 +221,14 @@ func TestGetUser(t *testing.T) {
nil)

b, _ = getMockedBackend(lm, lconfig, &logger)
u, err := b.GetUser(context.Background(), "user", nil)
u, err := b.GetUser(context.Background(), "user", odataReqDefault)
if err != nil {
t.Errorf("Expected GetUser to succeed. Got %s", err.Error())
} else if *u.Id != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id) {
t.Errorf("Expected GetUser to return a valid user")
}

u, err = b.GetUser(context.Background(), "user", queryParamExpand)
if err != nil {
t.Errorf("Expected GetUser to succeed. Got %s", err.Error())
} else if *u.Id != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id) {
t.Errorf("Expected GetUser to return a valid user")
}

u, err = b.GetUser(context.Background(), "user", queryParamSelect)
u, err = b.GetUser(context.Background(), "user", odataReqExpand)
if err != nil {
t.Errorf("Expected GetUser to succeed. Got %s", err.Error())
} else if *u.Id != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id) {
Expand Down Expand Up @@ -266,16 +256,22 @@ func TestGetUsers(t *testing.T) {
lm := &mocks.Client{}
lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock")))

odataReqDefault, err := godata.ParseRequest(context.Background(), "",
url.Values{})
if err != nil {
t.Errorf("Expected success got '%s'", err.Error())
}

b, _ := getMockedBackend(lm, lconfig, &logger)
_, err := b.GetUsers(context.Background(), url.Values{})
_, err = b.GetUsers(context.Background(), odataReqDefault)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}

lm = &mocks.Client{}
lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err := b.GetUsers(context.Background(), url.Values{})
g, err := b.GetUsers(context.Background(), odataReqDefault)
if err != nil {
t.Errorf("Expected success, got '%s'", err.Error())
} else if g == nil || len(g) != 0 {
Expand Down
30 changes: 16 additions & 14 deletions services/graph/pkg/identity/mocks/backend.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions services/graph/pkg/identity/odata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package identity

import "github.com/CiscoM31/godata"

// GetExpandValues extracts the values of the $expand query parameter and
// returns them in a []string, rejects any $expand value that consists of more
// than just a single path segment
func GetExpandValues(req *godata.GoDataQuery) ([]string, error) {
if req == nil || req.Expand == nil {
return []string{}, nil
}
expand := make([]string, 0, len(req.Expand.ExpandItems))
for _, item := range req.Expand.ExpandItems {
if item.Filter != nil || item.At != nil || item.Search != nil ||
item.OrderBy != nil || item.Skip != nil || item.Top != nil ||
item.Select != nil || item.Compute != nil || item.Expand != nil ||
item.Levels != 0 {
return []string{}, godata.NotImplementedError("options for $expand not supported")
}
if len(item.Path) > 1 {
return []string{}, godata.NotImplementedError("multiple segments in $expand not supported")
}
expand = append(expand, item.Path[0].Value)
}
return expand, nil
}

// GetSearchValues extracts the value of the $search query parameter and returns
// it as a string. Rejects any search query that is more than just a simple string
func GetSearchValues(req *godata.GoDataQuery) (string, error) {
if req == nil || req.Search == nil {
return "", nil
}

// Only allow simple search queries for now
if len(req.Search.Tree.Children) != 0 {
return "", godata.NotImplementedError("complex search queries are not supported")
}

return req.Search.Tree.Token.Value, nil
}
Loading

0 comments on commit 26f7523

Please sign in to comment.