Skip to content

Commit

Permalink
Avoid unneeded extra LDAP query per group
Browse files Browse the repository at this point in the history
When expanding members of an LDAP group we did two group lookup per
Group. This can be avoided by expanding the members right from the
Group entry of the first query.

This also add some more unit test coverage, especially to the expand/select
group member test cases.
  • Loading branch information
rhafer committed Aug 4, 2022
1 parent 7e28445 commit 837cc3f
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 50 deletions.
50 changes: 38 additions & 12 deletions services/graph/pkg/identity/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,14 +499,14 @@ func (i *LDAP) GetGroup(ctx context.Context, nameOrID string, queryParam url.Val
return nil, errorcode.New(errorcode.ItemNotFound, "not found")
}
if slices.Contains(sel, "members") || slices.Contains(exp, "members") {
members, err := i.GetGroupMembers(ctx, *g.Id)
members, err := i.expandLDAPGroupMembers(ctx, e)
if err != nil {
return nil, err
}
if len(members) > 0 {
m := make([]libregraph.User, 0, len(members))
for _, u := range members {
m = append(m, *u)
m = append(m, *i.createUserModelFromLDAP(u))
}
g.Members = m
}
Expand Down Expand Up @@ -629,6 +629,14 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr
if search == "" {
search = queryParam.Get("$search")
}

var expandMembers bool
sel := strings.Split(queryParam.Get("$select"), ",")
exp := strings.Split(queryParam.Get("$expand"), ",")
if slices.Contains(sel, "members") || slices.Contains(exp, "members") {
expandMembers = true
}

var groupFilter string
if search != "" {
search = ldap.EscapeFilter(search)
Expand All @@ -639,13 +647,19 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr
)
}
groupFilter = fmt.Sprintf("(&%s(objectClass=%s)%s)", i.groupFilter, i.groupObjectClass, groupFilter)

groupAttrs := []string{
i.groupAttributeMap.name,
i.groupAttributeMap.id,
}
if expandMembers {
groupAttrs = append(groupAttrs, i.groupAttributeMap.member)
}

searchRequest := ldap.NewSearchRequest(
i.groupBaseDN, i.groupScope, ldap.NeverDerefAliases, 0, 0, false,
groupFilter,
[]string{
i.groupAttributeMap.name,
i.groupAttributeMap.id,
},
groupAttrs,
nil,
)
i.logger.Debug().Str("backend", "ldap").
Expand All @@ -664,20 +678,18 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr

var g *libregraph.Group
for _, e := range res.Entries {
sel := strings.Split(queryParam.Get("$select"), ",")
exp := strings.Split(queryParam.Get("$expand"), ",")
if g = i.createGroupModelFromLDAP(e); g == nil {
continue
}
if slices.Contains(sel, "members") || slices.Contains(exp, "members") {
members, err := i.GetGroupMembers(ctx, *g.Id)
if expandMembers {
members, err := i.expandLDAPGroupMembers(ctx, e)
if err != nil {
return nil, err
}
if len(members) > 0 {
m := make([]libregraph.User, 0, len(members))
for _, u := range members {
m = append(m, *u)
m = append(m, *i.createUserModelFromLDAP(u))
}
g.Members = m
}
Expand All @@ -696,6 +708,20 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra

result := []*libregraph.User{}

memberEntries, err := i.expandLDAPGroupMembers(ctx, e)
if err != nil {
return nil, err
}
for _, member := range memberEntries {
result = append(result, i.createUserModelFromLDAP(member))
}

return result, nil
}

func (i *LDAP) expandLDAPGroupMembers(ctx context.Context, e *ldap.Entry) ([]*ldap.Entry, error) {
result := []*ldap.Entry{}

for _, memberDN := range e.GetEqualFoldAttributeValues(i.groupAttributeMap.member) {
if memberDN == "" {
continue
Expand All @@ -707,7 +733,7 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra
i.logger.Warn().Err(err).Str("member", memberDN).Msg("error reading group member")
continue
}
result = append(result, i.createUserModelFromLDAP(ue))
result = append(result, ue)
}

return result, nil
Expand Down
140 changes: 102 additions & 38 deletions services/graph/pkg/identity/ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@ func getMockedBackend(l ldap.Client, lc config.LDAP, logger *log.Logger) (*LDAP,
}

var lconfig = config.LDAP{
UserBaseDN: "dc=test",
UserBaseDN: "ou=people,dc=test",
UserObjectClass: "inetOrgPerson",
UserSearchScope: "sub",
UserFilter: "filter",
UserFilter: "",
UserDisplayNameAttribute: "displayname",
UserIDAttribute: "entryUUID",
UserEmailAttribute: "mail",
UserNameAttribute: "uid",

GroupBaseDN: "dc=test",
GroupBaseDN: "ou=groups,dc=test",
GroupObjectClass: "groupOfNames",
GroupSearchScope: "sub",
GroupFilter: "filter",
GroupFilter: "",
GroupNameAttribute: "cn",
GroupIDAttribute: "entryUUID",
}
Expand All @@ -44,6 +46,7 @@ var groupEntry = ldap.NewEntry("cn=group",
map[string][]string{
"cn": {"group"},
"entryuuid": {"abcd-defg"},
"member": {"uid=user,ou=people,dc=test"},
})
var invalidGroupEntry = ldap.NewEntry("cn=invalid",
map[string][]string{
Expand Down Expand Up @@ -116,8 +119,7 @@ func TestCreateUserModelFromLDAP(t *testing.T) {
func TestGetUser(t *testing.T) {
// Mock a Sizelimit Error
lm := &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
lm.On("Search", mock.Anything).
Return(
nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock")))
b, _ := getMockedBackend(lm, lconfig, &logger)
Expand Down Expand Up @@ -145,8 +147,7 @@ func TestGetUser(t *testing.T) {

// Mock an empty Search Result
lm = &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
lm.On("Search", mock.Anything).
Return(
&ldap.SearchResult{}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
Expand All @@ -165,10 +166,9 @@ func TestGetUser(t *testing.T) {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}

// Mock a valid Search Result
// Mock a valid Search Result
lm = &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
lm.On("Search", mock.Anything).
Return(
&ldap.SearchResult{
Entries: []*ldap.Entry{userEntry},
Expand Down Expand Up @@ -201,9 +201,7 @@ func TestGetUser(t *testing.T) {
func TestGetUsers(t *testing.T) {
// Mock a Sizelimit Error
lm := &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock")))
lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock")))

b, _ := getMockedBackend(lm, lconfig, &logger)
_, err := b.GetUsers(context.Background(), url.Values{})
Expand All @@ -212,9 +210,7 @@ func TestGetUsers(t *testing.T) {
}

lm = &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(&ldap.SearchResult{}, nil)
lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err := b.GetUsers(context.Background(), url.Values{})
if err != nil {
Expand All @@ -227,15 +223,13 @@ func TestGetUsers(t *testing.T) {
func TestGetGroup(t *testing.T) {
// Mock a Sizelimit Error
lm := &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock")))
lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock")))

queryParamExpand := url.Values{
"$expand": []string{"memberOf"},
"$expand": []string{"members"},
}
queryParamSelect := url.Values{
"$select": []string{"memberOf"},
"$select": []string{"members"},
}
b, _ := getMockedBackend(lm, lconfig, &logger)
_, err := b.GetGroup(context.Background(), "group", nil)
Expand All @@ -253,9 +247,7 @@ func TestGetGroup(t *testing.T) {

// Mock an empty Search Result
lm = &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(&ldap.SearchResult{}, nil)
lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
_, err = b.GetGroup(context.Background(), "group", nil)
if err == nil || err.Error() != "itemNotFound" {
Expand Down Expand Up @@ -291,9 +283,24 @@ func TestGetGroup(t *testing.T) {

// Mock a valid Search Result
lm = &mocks.Client{}
lm.On("Search", mock.Anything).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{groupEntry},
}, nil)
sr1 := &ldap.SearchRequest{
BaseDN: "ou=groups,dc=test",
Scope: 2,
SizeLimit: 1,
Filter: "(&(objectClass=groupOfNames)(|(cn=group)(entryUUID=group)))",
Attributes: []string{"cn", "entryUUID", "member"},
Controls: []ldap.Control(nil),
}
sr2 := &ldap.SearchRequest{
BaseDN: "uid=user,ou=people,dc=test",
SizeLimit: 1,
Filter: "(objectclass=*)",
Attributes: []string{"displayname", "entryUUID", "mail", "uid"},
Controls: []ldap.Control(nil),
}

lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil)
lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err = b.GetGroup(context.Background(), "group", nil)
if err != nil {
Expand All @@ -302,39 +309,96 @@ func TestGetGroup(t *testing.T) {
t.Errorf("Expected GetGroup to return a valid group")
}
g, err = b.GetGroup(context.Background(), "group", queryParamExpand)
if err != nil {
switch {
case err != nil:
t.Errorf("Expected GetGroup to succeed. Got %s", err.Error())
} else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) {
case g.GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id):
t.Errorf("Expected GetGroup to return a valid group")
case len(g.Members) != 1:
t.Errorf("Expected GetGroup with expand to return one member")
case g.Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id):
t.Errorf("Expected GetGroup with expand to return correct member")
}
g, err = b.GetGroup(context.Background(), "group", queryParamSelect)
if err != nil {
switch {
case err != nil:
t.Errorf("Expected GetGroup to succeed. Got %s", err.Error())
} else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) {
case g.GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id):
t.Errorf("Expected GetGroup to return a valid group")
case len(g.Members) != 1:
t.Errorf("Expected GetGroup with expand to return one member")
case g.Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id):
t.Errorf("Expected GetGroup with expand to return correct member")
}
}

func TestGetGroups(t *testing.T) {
queryParamExpand := url.Values{
"$expand": []string{"members"},
}
queryParamSelect := url.Values{
"$select": []string{"members"},
}
lm := &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock")))
lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock")))
b, _ := getMockedBackend(lm, lconfig, &logger)
_, err := b.GetGroups(context.Background(), url.Values{})
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}

lm = &mocks.Client{}
lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(&ldap.SearchResult{}, nil)
lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err := b.GetGroups(context.Background(), url.Values{})
if err != nil {
t.Errorf("Expected success, got '%s'", err.Error())
} else if g == nil || len(g) != 0 {
t.Errorf("Expected zero length user slice")
}

lm = &mocks.Client{}
lm.On("Search", mock.Anything).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{groupEntry},
}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err = b.GetGroups(context.Background(), url.Values{})
if err != nil {
t.Errorf("Expected GetGroup to succeed. Got %s", err.Error())
} else if *g[0].Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) {
t.Errorf("Expected GetGroup to return a valid group")
}

// Mock a valid Search Result with expanded group members
lm = &mocks.Client{}
sr1 := &ldap.SearchRequest{
BaseDN: "ou=groups,dc=test",
Scope: 2,
Filter: "(&(objectClass=groupOfNames))",
Attributes: []string{"cn", "entryUUID", "member"},
Controls: []ldap.Control(nil),
}
sr2 := &ldap.SearchRequest{
BaseDN: "uid=user,ou=people,dc=test",
SizeLimit: 1,
Filter: "(objectclass=*)",
Attributes: []string{"displayname", "entryUUID", "mail", "uid"},
Controls: []ldap.Control(nil),
}
for _, param := range []url.Values{queryParamSelect, queryParamExpand} {
lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil)
lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err = b.GetGroups(context.Background(), param)
switch {
case err != nil:
t.Errorf("Expected GetGroup to succeed. Got %s", err.Error())
case g[0].GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id):
t.Errorf("Expected GetGroup to return a valid group")
case len(g[0].Members) != 1:
t.Errorf("Expected GetGroup to return group with one member")
case g[0].Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id):
t.Errorf("Expected GetGroup to return group with correct member")
}
}
}

0 comments on commit 837cc3f

Please sign in to comment.