From 837cc3fd676e9d90bcf6afca1c3d6626b8272678 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Thu, 4 Aug 2022 10:05:34 +0200 Subject: [PATCH] Avoid unneeded extra LDAP query per group When expanding members of an LDAP group we did two group lookup per Group. This can be avoided by expanding the members right from the Group entry of the first query. This also add some more unit test coverage, especially to the expand/select group member test cases. --- services/graph/pkg/identity/ldap.go | 50 ++++++-- services/graph/pkg/identity/ldap_test.go | 140 +++++++++++++++++------ 2 files changed, 140 insertions(+), 50 deletions(-) diff --git a/services/graph/pkg/identity/ldap.go b/services/graph/pkg/identity/ldap.go index 2dc3f23d3c9..eaab3946c45 100644 --- a/services/graph/pkg/identity/ldap.go +++ b/services/graph/pkg/identity/ldap.go @@ -499,14 +499,14 @@ func (i *LDAP) GetGroup(ctx context.Context, nameOrID string, queryParam url.Val return nil, errorcode.New(errorcode.ItemNotFound, "not found") } if slices.Contains(sel, "members") || slices.Contains(exp, "members") { - members, err := i.GetGroupMembers(ctx, *g.Id) + members, err := i.expandLDAPGroupMembers(ctx, e) if err != nil { return nil, err } if len(members) > 0 { m := make([]libregraph.User, 0, len(members)) for _, u := range members { - m = append(m, *u) + m = append(m, *i.createUserModelFromLDAP(u)) } g.Members = m } @@ -629,6 +629,14 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr if search == "" { search = queryParam.Get("$search") } + + var expandMembers bool + sel := strings.Split(queryParam.Get("$select"), ",") + exp := strings.Split(queryParam.Get("$expand"), ",") + if slices.Contains(sel, "members") || slices.Contains(exp, "members") { + expandMembers = true + } + var groupFilter string if search != "" { search = ldap.EscapeFilter(search) @@ -639,13 +647,19 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr ) } groupFilter = fmt.Sprintf("(&%s(objectClass=%s)%s)", i.groupFilter, i.groupObjectClass, groupFilter) + + groupAttrs := []string{ + i.groupAttributeMap.name, + i.groupAttributeMap.id, + } + if expandMembers { + groupAttrs = append(groupAttrs, i.groupAttributeMap.member) + } + searchRequest := ldap.NewSearchRequest( i.groupBaseDN, i.groupScope, ldap.NeverDerefAliases, 0, 0, false, groupFilter, - []string{ - i.groupAttributeMap.name, - i.groupAttributeMap.id, - }, + groupAttrs, nil, ) i.logger.Debug().Str("backend", "ldap"). @@ -664,20 +678,18 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr var g *libregraph.Group for _, e := range res.Entries { - sel := strings.Split(queryParam.Get("$select"), ",") - exp := strings.Split(queryParam.Get("$expand"), ",") if g = i.createGroupModelFromLDAP(e); g == nil { continue } - if slices.Contains(sel, "members") || slices.Contains(exp, "members") { - members, err := i.GetGroupMembers(ctx, *g.Id) + if expandMembers { + members, err := i.expandLDAPGroupMembers(ctx, e) if err != nil { return nil, err } if len(members) > 0 { m := make([]libregraph.User, 0, len(members)) for _, u := range members { - m = append(m, *u) + m = append(m, *i.createUserModelFromLDAP(u)) } g.Members = m } @@ -696,6 +708,20 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra result := []*libregraph.User{} + memberEntries, err := i.expandLDAPGroupMembers(ctx, e) + if err != nil { + return nil, err + } + for _, member := range memberEntries { + result = append(result, i.createUserModelFromLDAP(member)) + } + + return result, nil +} + +func (i *LDAP) expandLDAPGroupMembers(ctx context.Context, e *ldap.Entry) ([]*ldap.Entry, error) { + result := []*ldap.Entry{} + for _, memberDN := range e.GetEqualFoldAttributeValues(i.groupAttributeMap.member) { if memberDN == "" { continue @@ -707,7 +733,7 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra i.logger.Warn().Err(err).Str("member", memberDN).Msg("error reading group member") continue } - result = append(result, i.createUserModelFromLDAP(ue)) + result = append(result, ue) } return result, nil diff --git a/services/graph/pkg/identity/ldap_test.go b/services/graph/pkg/identity/ldap_test.go index aa3fc790a89..7f561e78656 100644 --- a/services/graph/pkg/identity/ldap_test.go +++ b/services/graph/pkg/identity/ldap_test.go @@ -18,17 +18,19 @@ func getMockedBackend(l ldap.Client, lc config.LDAP, logger *log.Logger) (*LDAP, } var lconfig = config.LDAP{ - UserBaseDN: "dc=test", + UserBaseDN: "ou=people,dc=test", + UserObjectClass: "inetOrgPerson", UserSearchScope: "sub", - UserFilter: "filter", + UserFilter: "", UserDisplayNameAttribute: "displayname", UserIDAttribute: "entryUUID", UserEmailAttribute: "mail", UserNameAttribute: "uid", - GroupBaseDN: "dc=test", + GroupBaseDN: "ou=groups,dc=test", + GroupObjectClass: "groupOfNames", GroupSearchScope: "sub", - GroupFilter: "filter", + GroupFilter: "", GroupNameAttribute: "cn", GroupIDAttribute: "entryUUID", } @@ -44,6 +46,7 @@ var groupEntry = ldap.NewEntry("cn=group", map[string][]string{ "cn": {"group"}, "entryuuid": {"abcd-defg"}, + "member": {"uid=user,ou=people,dc=test"}, }) var invalidGroupEntry = ldap.NewEntry("cn=invalid", map[string][]string{ @@ -116,8 +119,7 @@ func TestCreateUserModelFromLDAP(t *testing.T) { func TestGetUser(t *testing.T) { // Mock a Sizelimit Error lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + lm.On("Search", mock.Anything). Return( nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) @@ -145,8 +147,7 @@ func TestGetUser(t *testing.T) { // Mock an empty Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + lm.On("Search", mock.Anything). Return( &ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) @@ -165,10 +166,9 @@ func TestGetUser(t *testing.T) { t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) } - // Mock a valid Search Result + // Mock a valid Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + lm.On("Search", mock.Anything). Return( &ldap.SearchResult{ Entries: []*ldap.Entry{userEntry}, @@ -201,9 +201,7 @@ func TestGetUser(t *testing.T) { func TestGetUsers(t *testing.T) { // Mock a Sizelimit Error lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) + lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetUsers(context.Background(), url.Values{}) @@ -212,9 +210,7 @@ func TestGetUsers(t *testing.T) { } lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{}, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err := b.GetUsers(context.Background(), url.Values{}) if err != nil { @@ -227,15 +223,13 @@ func TestGetUsers(t *testing.T) { func TestGetGroup(t *testing.T) { // Mock a Sizelimit Error lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock"))) + lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock"))) queryParamExpand := url.Values{ - "$expand": []string{"memberOf"}, + "$expand": []string{"members"}, } queryParamSelect := url.Values{ - "$select": []string{"memberOf"}, + "$select": []string{"members"}, } b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetGroup(context.Background(), "group", nil) @@ -253,9 +247,7 @@ func TestGetGroup(t *testing.T) { // Mock an empty Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{}, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) _, err = b.GetGroup(context.Background(), "group", nil) if err == nil || err.Error() != "itemNotFound" { @@ -291,9 +283,24 @@ func TestGetGroup(t *testing.T) { // Mock a valid Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything).Return(&ldap.SearchResult{ - Entries: []*ldap.Entry{groupEntry}, - }, nil) + sr1 := &ldap.SearchRequest{ + BaseDN: "ou=groups,dc=test", + Scope: 2, + SizeLimit: 1, + Filter: "(&(objectClass=groupOfNames)(|(cn=group)(entryUUID=group)))", + Attributes: []string{"cn", "entryUUID", "member"}, + Controls: []ldap.Control(nil), + } + sr2 := &ldap.SearchRequest{ + BaseDN: "uid=user,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + + lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil) + lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err = b.GetGroup(context.Background(), "group", nil) if err != nil { @@ -302,24 +309,38 @@ func TestGetGroup(t *testing.T) { t.Errorf("Expected GetGroup to return a valid group") } g, err = b.GetGroup(context.Background(), "group", queryParamExpand) - if err != nil { + switch { + case err != nil: t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) - } else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { + case g.GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id): t.Errorf("Expected GetGroup to return a valid group") + case len(g.Members) != 1: + t.Errorf("Expected GetGroup with expand to return one member") + case g.Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id): + t.Errorf("Expected GetGroup with expand to return correct member") } g, err = b.GetGroup(context.Background(), "group", queryParamSelect) - if err != nil { + switch { + case err != nil: t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) - } else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { + case g.GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id): t.Errorf("Expected GetGroup to return a valid group") + case len(g.Members) != 1: + t.Errorf("Expected GetGroup with expand to return one member") + case g.Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id): + t.Errorf("Expected GetGroup with expand to return correct member") } } func TestGetGroups(t *testing.T) { + queryParamExpand := url.Values{ + "$expand": []string{"members"}, + } + queryParamSelect := url.Values{ + "$select": []string{"members"}, + } lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) + lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetGroups(context.Background(), url.Values{}) if err == nil || err.Error() != "itemNotFound" { @@ -327,9 +348,7 @@ func TestGetGroups(t *testing.T) { } lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{}, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err := b.GetGroups(context.Background(), url.Values{}) if err != nil { @@ -337,4 +356,49 @@ func TestGetGroups(t *testing.T) { } else if g == nil || len(g) != 0 { t.Errorf("Expected zero length user slice") } + + lm = &mocks.Client{} + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{groupEntry}, + }, nil) + b, _ = getMockedBackend(lm, lconfig, &logger) + g, err = b.GetGroups(context.Background(), url.Values{}) + if err != nil { + t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) + } else if *g[0].Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { + t.Errorf("Expected GetGroup to return a valid group") + } + + // Mock a valid Search Result with expanded group members + lm = &mocks.Client{} + sr1 := &ldap.SearchRequest{ + BaseDN: "ou=groups,dc=test", + Scope: 2, + Filter: "(&(objectClass=groupOfNames))", + Attributes: []string{"cn", "entryUUID", "member"}, + Controls: []ldap.Control(nil), + } + sr2 := &ldap.SearchRequest{ + BaseDN: "uid=user,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + for _, param := range []url.Values{queryParamSelect, queryParamExpand} { + lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil) + lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil) + b, _ = getMockedBackend(lm, lconfig, &logger) + g, err = b.GetGroups(context.Background(), param) + switch { + case err != nil: + t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) + case g[0].GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id): + t.Errorf("Expected GetGroup to return a valid group") + case len(g[0].Members) != 1: + t.Errorf("Expected GetGroup to return group with one member") + case g[0].Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id): + t.Errorf("Expected GetGroup to return group with correct member") + } + } }