diff --git a/changelog/unreleased/fix-ldap-invalid-users-groups.md b/changelog/unreleased/fix-ldap-invalid-users-groups.md new file mode 100644 index 00000000000..9f0db8640a8 --- /dev/null +++ b/changelog/unreleased/fix-ldap-invalid-users-groups.md @@ -0,0 +1,6 @@ +Bugfix: Fix handling of invalid LDAP users and groups + +We fixed an issue where ocis would exit with a panic when LDAP users +or groups where missing required attributes (e.g. the id) + +https://github.com/owncloud/ocis/issues/4274 diff --git a/services/graph/pkg/identity/ldap.go b/services/graph/pkg/identity/ldap.go index 398fabaca63..24f8678c647 100644 --- a/services/graph/pkg/identity/ldap.go +++ b/services/graph/pkg/identity/ldap.go @@ -315,7 +315,13 @@ func (i *LDAP) getEntryByDN(dn string, attrs []string) (*ldap.Entry, error) { nil, ) - i.logger.Debug().Str("backend", "ldap").Str("dn", dn).Msg("Search user by DN") + i.logger.Debug().Str("backend", "ldap"). + Str("base", searchRequest.BaseDN). + Str("filter", searchRequest.Filter). + Int("scope", searchRequest.Scope). + Int("sizelimit", searchRequest.SizeLimit). + Interface("attributes", searchRequest.Attributes). + Msg("getEntryByDN") res, err := i.conn.Search(searchRequest) if err != nil { @@ -353,7 +359,13 @@ func (i *LDAP) getLDAPUserByFilter(filter string) (*ldap.Entry, error) { }, nil, ) - i.logger.Debug().Str("backend", "ldap").Msgf("Search %s", i.userBaseDN) + i.logger.Debug().Str("backend", "ldap"). + Str("base", searchRequest.BaseDN). + Str("filter", searchRequest.Filter). + Int("scope", searchRequest.Scope). + Int("sizelimit", searchRequest.SizeLimit). + Interface("attributes", searchRequest.Attributes). + Msg("getLDAPUserByFilter") res, err := i.conn.Search(searchRequest) if err != nil { @@ -380,19 +392,19 @@ func (i *LDAP) GetUser(ctx context.Context, nameOrID string, queryParam url.Valu if err != nil { return nil, err } + u := i.createUserModelFromLDAP(e) + if u == nil { + return nil, errNotFound + } sel := strings.Split(queryParam.Get("$select"), ",") exp := strings.Split(queryParam.Get("$expand"), ",") - u := i.createUserModelFromLDAP(e) if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") { userGroups, err := i.getGroupsForUser(e.DN) if err != nil { return nil, err } if len(userGroups) > 0 { - groups := make([]libregraph.Group, 0, len(userGroups)) - for _, g := range userGroups { - groups = append(groups, *i.createGroupModelFromLDAP(g)) - } + groups := i.groupsFromLDAPEntries(userGroups) u.MemberOf = groups } } @@ -428,7 +440,13 @@ func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregra }, nil, ) - i.logger.Debug().Str("backend", "ldap").Msgf("Search %s", i.userBaseDN) + i.logger.Debug().Str("backend", "ldap"). + Str("base", searchRequest.BaseDN). + Str("filter", searchRequest.Filter). + Int("scope", searchRequest.Scope). + Int("sizelimit", searchRequest.SizeLimit). + Interface("attributes", searchRequest.Attributes). + Msg("GetUsers") res, err := i.conn.Search(searchRequest) if err != nil { return nil, errorcode.New(errorcode.ItemNotFound, err.Error()) @@ -440,6 +458,10 @@ func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregra sel := strings.Split(queryParam.Get("$select"), ",") exp := strings.Split(queryParam.Get("$expand"), ",") u := i.createUserModelFromLDAP(e) + // Skip invalid LDAP users + if u == nil { + continue + } if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") { userGroups, err := i.getGroupsForUser(e.DN) if err != nil { @@ -450,10 +472,7 @@ func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregra if expand == "" { expand = "false" } - groups := make([]libregraph.Group, 0, len(userGroups)) - for _, g := range userGroups { - groups = append(groups, *i.createGroupModelFromLDAP(g)) - } + groups := i.groupsFromLDAPEntries(userGroups) u.MemberOf = groups } } @@ -482,16 +501,21 @@ func (i *LDAP) GetGroup(ctx context.Context, nameOrID string, queryParam url.Val } sel := strings.Split(queryParam.Get("$select"), ",") exp := strings.Split(queryParam.Get("$expand"), ",") - g := i.createGroupModelFromLDAP(e) + var g *libregraph.Group + if g = i.createGroupModelFromLDAP(e); g == nil { + return nil, errorcode.New(errorcode.ItemNotFound, "not found") + } if slices.Contains(sel, "members") || slices.Contains(exp, "members") { - members, err := i.GetGroupMembers(ctx, *g.Id) + members, err := i.expandLDAPGroupMembers(ctx, e) if err != nil { return nil, err } if len(members) > 0 { m := make([]libregraph.User, 0, len(members)) - for _, u := range members { - m = append(m, *u) + for _, ue := range members { + if u := i.createUserModelFromLDAP(ue); u != nil { + m = append(m, *u) + } } g.Members = m } @@ -546,7 +570,13 @@ func (i *LDAP) getLDAPGroupsByFilter(filter string, requestMembers, single bool) attrs, nil, ) - i.logger.Debug().Str("backend", "ldap").Msgf("Search %s", i.groupBaseDN) + i.logger.Debug().Str("backend", "ldap"). + Str("base", searchRequest.BaseDN). + Str("filter", searchRequest.Filter). + Int("scope", searchRequest.Scope). + Int("sizelimit", searchRequest.SizeLimit). + Interface("attributes", searchRequest.Attributes). + Msg("getLDAPGroupsByFilter") res, err := i.conn.Search(searchRequest) if err != nil { @@ -608,6 +638,14 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr if search == "" { search = queryParam.Get("$search") } + + var expandMembers bool + sel := strings.Split(queryParam.Get("$select"), ",") + exp := strings.Split(queryParam.Get("$expand"), ",") + if slices.Contains(sel, "members") || slices.Contains(exp, "members") { + expandMembers = true + } + var groupFilter string if search != "" { search = ldap.EscapeFilter(search) @@ -618,16 +656,28 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr ) } groupFilter = fmt.Sprintf("(&%s(objectClass=%s)%s)", i.groupFilter, i.groupObjectClass, groupFilter) + + groupAttrs := []string{ + i.groupAttributeMap.name, + i.groupAttributeMap.id, + } + if expandMembers { + groupAttrs = append(groupAttrs, i.groupAttributeMap.member) + } + searchRequest := ldap.NewSearchRequest( i.groupBaseDN, i.groupScope, ldap.NeverDerefAliases, 0, 0, false, groupFilter, - []string{ - i.groupAttributeMap.name, - i.groupAttributeMap.id, - }, + groupAttrs, nil, ) - i.logger.Debug().Str("backend", "ldap").Str("Base", i.groupBaseDN).Str("filter", groupFilter).Msg("ldap search") + i.logger.Debug().Str("backend", "ldap"). + Str("base", searchRequest.BaseDN). + Str("filter", searchRequest.Filter). + Int("scope", searchRequest.Scope). + Int("sizelimit", searchRequest.SizeLimit). + Interface("attributes", searchRequest.Attributes). + Msg("GetGroups") res, err := i.conn.Search(searchRequest) if err != nil { return nil, errorcode.New(errorcode.ItemNotFound, err.Error()) @@ -635,19 +685,22 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr groups := make([]*libregraph.Group, 0, len(res.Entries)) + var g *libregraph.Group for _, e := range res.Entries { - sel := strings.Split(queryParam.Get("$select"), ",") - exp := strings.Split(queryParam.Get("$expand"), ",") - g := i.createGroupModelFromLDAP(e) - if slices.Contains(sel, "members") || slices.Contains(exp, "members") { - members, err := i.GetGroupMembers(ctx, *g.Id) + if g = i.createGroupModelFromLDAP(e); g == nil { + continue + } + if expandMembers { + members, err := i.expandLDAPGroupMembers(ctx, e) if err != nil { return nil, err } if len(members) > 0 { m := make([]libregraph.User, 0, len(members)) - for _, u := range members { - m = append(m, *u) + for _, ue := range members { + if u := i.createUserModelFromLDAP(ue); u != nil { + m = append(m, *u) + } } g.Members = m } @@ -664,7 +717,22 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra return nil, err } - result := []*libregraph.User{} + memberEntries, err := i.expandLDAPGroupMembers(ctx, e) + result := make([]*libregraph.User, 0, len(memberEntries)) + if err != nil { + return nil, err + } + for _, member := range memberEntries { + if u := i.createUserModelFromLDAP(member); u != nil { + result = append(result, u) + } + } + + return result, nil +} + +func (i *LDAP) expandLDAPGroupMembers(ctx context.Context, e *ldap.Entry) ([]*ldap.Entry, error) { + result := []*ldap.Entry{} for _, memberDN := range e.GetEqualFoldAttributeValues(i.groupAttributeMap.member) { if memberDN == "" { @@ -677,7 +745,7 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra i.logger.Warn().Err(err).Str("member", memberDN).Msg("error reading group member") continue } - result = append(result, i.createUserModelFromLDAP(ue)) + result = append(result, ue) } return result, nil @@ -853,19 +921,44 @@ func (i *LDAP) createUserModelFromLDAP(e *ldap.Entry) *libregraph.User { if e == nil { return nil } - return &libregraph.User{ - DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.displayName)), - Mail: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.mail)), - OnPremisesSamAccountName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.userName)), - Id: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.id)), + + opsan := e.GetEqualFoldAttributeValue(i.userAttributeMap.userName) + id := e.GetEqualFoldAttributeValue(i.userAttributeMap.id) + + if id != "" && opsan != "" { + return &libregraph.User{ + DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.displayName)), + Mail: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.mail)), + OnPremisesSamAccountName: &opsan, + Id: &id, + } } + i.logger.Warn().Str("dn", e.DN).Msg("Invalid User. Missing username or id attribute") + return nil } func (i *LDAP) createGroupModelFromLDAP(e *ldap.Entry) *libregraph.Group { - return &libregraph.Group{ - DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.groupAttributeMap.name)), - Id: pointerOrNil(e.GetEqualFoldAttributeValue(i.groupAttributeMap.id)), + name := e.GetEqualFoldAttributeValue(i.groupAttributeMap.name) + id := e.GetEqualFoldAttributeValue(i.groupAttributeMap.id) + + if id != "" && name != "" { + return &libregraph.Group{ + DisplayName: &name, + Id: &id, + } + } + i.logger.Warn().Str("dn", e.DN).Msg("Group is missing name or id") + return nil +} + +func (i *LDAP) groupsFromLDAPEntries(e []*ldap.Entry) []libregraph.Group { + groups := make([]libregraph.Group, 0, len(e)) + for _, g := range e { + if grp := i.createGroupModelFromLDAP(g); grp != nil { + groups = append(groups, *grp) + } } + return groups } func pointerOrNil(val string) *string { diff --git a/services/graph/pkg/identity/ldap_test.go b/services/graph/pkg/identity/ldap_test.go index 7db012edbf4..b442a906542 100644 --- a/services/graph/pkg/identity/ldap_test.go +++ b/services/graph/pkg/identity/ldap_test.go @@ -18,17 +18,19 @@ func getMockedBackend(l ldap.Client, lc config.LDAP, logger *log.Logger) (*LDAP, } var lconfig = config.LDAP{ - UserBaseDN: "dc=test", + UserBaseDN: "ou=people,dc=test", + UserObjectClass: "inetOrgPerson", UserSearchScope: "sub", - UserFilter: "filter", + UserFilter: "", UserDisplayNameAttribute: "displayname", UserIDAttribute: "entryUUID", UserEmailAttribute: "mail", UserNameAttribute: "uid", - GroupBaseDN: "dc=test", + GroupBaseDN: "ou=groups,dc=test", + GroupObjectClass: "groupOfNames", GroupSearchScope: "sub", - GroupFilter: "filter", + GroupFilter: "", GroupNameAttribute: "cn", GroupIDAttribute: "entryUUID", } @@ -40,10 +42,24 @@ var userEntry = ldap.NewEntry("uid=user", "mail": {"user@example"}, "entryuuid": {"abcd-defg"}, }) +var invalidUserEntry = ldap.NewEntry("uid=user", + map[string][]string{ + "uid": {"invalid"}, + "displayname": {"DisplayName"}, + "mail": {"user@example"}, + }) var groupEntry = ldap.NewEntry("cn=group", map[string][]string{ "cn": {"group"}, "entryuuid": {"abcd-defg"}, + "member": { + "uid=user,ou=people,dc=test", + "uid=invalid,ou=people,dc=test", + }, + }) +var invalidGroupEntry = ldap.NewEntry("cn=invalid", + map[string][]string{ + "cn": {"invalid"}, }) var logger = log.NewLogger(log.Level("debug")) @@ -112,8 +128,7 @@ func TestCreateUserModelFromLDAP(t *testing.T) { func TestGetUser(t *testing.T) { // Mock a Sizelimit Error lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + lm.On("Search", mock.Anything). Return( nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) @@ -141,8 +156,7 @@ func TestGetUser(t *testing.T) { // Mock an empty Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + lm.On("Search", mock.Anything). Return( &ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) @@ -161,10 +175,9 @@ func TestGetUser(t *testing.T) { t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) } - // Mock a valid Search Result + // Mock a valid Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + lm.On("Search", mock.Anything). Return( &ldap.SearchResult{ Entries: []*ldap.Entry{userEntry}, @@ -192,14 +205,27 @@ func TestGetUser(t *testing.T) { } else if *u.Id != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id) { t.Errorf("Expected GetUser to return a valid user") } + + // Mock invalid Search Result + lm = &mocks.Client{} + lm.On("Search", mock.Anything). + Return( + &ldap.SearchResult{ + Entries: []*ldap.Entry{invalidUserEntry}, + }, + nil) + + b, _ = getMockedBackend(lm, lconfig, &logger) + u, err = b.GetUser(context.Background(), "invalid", nil) + if err == nil || err.Error() != "itemNotFound" { + t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) + } } func TestGetUsers(t *testing.T) { // Mock a Sizelimit Error lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) + lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetUsers(context.Background(), url.Values{}) @@ -208,9 +234,7 @@ func TestGetUsers(t *testing.T) { } lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{}, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err := b.GetUsers(context.Background(), url.Values{}) if err != nil { @@ -223,15 +247,13 @@ func TestGetUsers(t *testing.T) { func TestGetGroup(t *testing.T) { // Mock a Sizelimit Error lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock"))) + lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultSizeLimitExceeded, errors.New("mock"))) queryParamExpand := url.Values{ - "$expand": []string{"memberOf"}, + "$expand": []string{"members"}, } queryParamSelect := url.Values{ - "$select": []string{"memberOf"}, + "$select": []string{"members"}, } b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetGroup(context.Background(), "group", nil) @@ -249,9 +271,7 @@ func TestGetGroup(t *testing.T) { // Mock an empty Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{}, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) _, err = b.GetGroup(context.Background(), "group", nil) if err == nil || err.Error() != "itemNotFound" { @@ -266,39 +286,93 @@ func TestGetGroup(t *testing.T) { t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) } - // Mock a valid Search Result + // Mock an invalid Search Result lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{ - Entries: []*ldap.Entry{groupEntry}, - }, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{invalidGroupEntry}, + }, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err := b.GetGroup(context.Background(), "group", nil) + if err == nil || err.Error() != "itemNotFound" { + t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) + } + g, err = b.GetGroup(context.Background(), "group", queryParamExpand) + if err == nil || err.Error() != "itemNotFound" { + t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) + } + g, err = b.GetGroup(context.Background(), "group", queryParamSelect) + if err == nil || err.Error() != "itemNotFound" { + t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) + } + + // Mock a valid Search Result + lm = &mocks.Client{} + sr1 := &ldap.SearchRequest{ + BaseDN: "ou=groups,dc=test", + Scope: 2, + SizeLimit: 1, + Filter: "(&(objectClass=groupOfNames)(|(cn=group)(entryUUID=group)))", + Attributes: []string{"cn", "entryUUID", "member"}, + Controls: []ldap.Control(nil), + } + sr2 := &ldap.SearchRequest{ + BaseDN: "uid=user,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + sr3 := &ldap.SearchRequest{ + BaseDN: "uid=invalid,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + + lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil) + lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil) + lm.On("Search", sr3).Return(&ldap.SearchResult{Entries: []*ldap.Entry{invalidUserEntry}}, nil) + b, _ = getMockedBackend(lm, lconfig, &logger) + g, err = b.GetGroup(context.Background(), "group", nil) if err != nil { t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) } else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { t.Errorf("Expected GetGroup to return a valid group") } g, err = b.GetGroup(context.Background(), "group", queryParamExpand) - if err != nil { + switch { + case err != nil: t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) - } else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { + case g.GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id): t.Errorf("Expected GetGroup to return a valid group") + case len(g.Members) != 1: + t.Errorf("Expected GetGroup with expand to return one member") + case g.Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id): + t.Errorf("Expected GetGroup with expand to return correct member") } g, err = b.GetGroup(context.Background(), "group", queryParamSelect) - if err != nil { + switch { + case err != nil: t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) - } else if *g.Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { + case g.GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id): t.Errorf("Expected GetGroup to return a valid group") + case len(g.Members) != 1: + t.Errorf("Expected GetGroup with expand to return one member") + case g.Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id): + t.Errorf("Expected GetGroup with expand to return correct member") } } func TestGetGroups(t *testing.T) { + queryParamExpand := url.Values{ + "$expand": []string{"members"}, + } + queryParamSelect := url.Values{ + "$select": []string{"members"}, + } lm := &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) + lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetGroups(context.Background(), url.Values{}) if err == nil || err.Error() != "itemNotFound" { @@ -306,9 +380,7 @@ func TestGetGroups(t *testing.T) { } lm = &mocks.Client{} - lm.On("Search", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(&ldap.SearchResult{}, nil) + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err := b.GetGroups(context.Background(), url.Values{}) if err != nil { @@ -316,4 +388,58 @@ func TestGetGroups(t *testing.T) { } else if g == nil || len(g) != 0 { t.Errorf("Expected zero length user slice") } + + lm = &mocks.Client{} + lm.On("Search", mock.Anything).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{groupEntry}, + }, nil) + b, _ = getMockedBackend(lm, lconfig, &logger) + g, err = b.GetGroups(context.Background(), url.Values{}) + if err != nil { + t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) + } else if *g[0].Id != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id) { + t.Errorf("Expected GetGroup to return a valid group") + } + + // Mock a valid Search Result with expanded group members + lm = &mocks.Client{} + sr1 := &ldap.SearchRequest{ + BaseDN: "ou=groups,dc=test", + Scope: 2, + Filter: "(&(objectClass=groupOfNames))", + Attributes: []string{"cn", "entryUUID", "member"}, + Controls: []ldap.Control(nil), + } + sr2 := &ldap.SearchRequest{ + BaseDN: "uid=user,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + sr3 := &ldap.SearchRequest{ + BaseDN: "uid=invalid,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + + for _, param := range []url.Values{queryParamSelect, queryParamExpand} { + lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil) + lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil) + lm.On("Search", sr3).Return(&ldap.SearchResult{Entries: []*ldap.Entry{invalidUserEntry}}, nil) + b, _ = getMockedBackend(lm, lconfig, &logger) + g, err = b.GetGroups(context.Background(), param) + switch { + case err != nil: + t.Errorf("Expected GetGroup to succeed. Got %s", err.Error()) + case g[0].GetId() != groupEntry.GetEqualFoldAttributeValue(b.groupAttributeMap.id): + t.Errorf("Expected GetGroup to return a valid group") + case len(g[0].Members) != 1: + t.Errorf("Expected GetGroup to return group with one member") + case g[0].Members[0].GetId() != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id): + t.Errorf("Expected GetGroup to return group with correct member") + } + } }