Skip to content

Commit

Permalink
Avoid panics when LDAP users miss required attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
rhafer committed Aug 4, 2022
1 parent 837cc3f commit 815d355
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
46 changes: 33 additions & 13 deletions services/graph/pkg/identity/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,12 @@ func (i *LDAP) GetUser(ctx context.Context, nameOrID string, queryParam url.Valu
if err != nil {
return nil, err
}
u := i.createUserModelFromLDAP(e)
if u == nil {
return nil, errNotFound
}
sel := strings.Split(queryParam.Get("$select"), ",")
exp := strings.Split(queryParam.Get("$expand"), ",")
u := i.createUserModelFromLDAP(e)
if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") {
userGroups, err := i.getGroupsForUser(e.DN)
if err != nil {
Expand Down Expand Up @@ -455,6 +458,10 @@ func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregra
sel := strings.Split(queryParam.Get("$select"), ",")
exp := strings.Split(queryParam.Get("$expand"), ",")
u := i.createUserModelFromLDAP(e)
// Skip invalid LDAP users
if u == nil {
continue
}
if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") {
userGroups, err := i.getGroupsForUser(e.DN)
if err != nil {
Expand Down Expand Up @@ -505,8 +512,10 @@ func (i *LDAP) GetGroup(ctx context.Context, nameOrID string, queryParam url.Val
}
if len(members) > 0 {
m := make([]libregraph.User, 0, len(members))
for _, u := range members {
m = append(m, *i.createUserModelFromLDAP(u))
for _, ue := range members {
if u := i.createUserModelFromLDAP(ue); u != nil {
m = append(m, *u)
}
}
g.Members = m
}
Expand Down Expand Up @@ -688,8 +697,10 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr
}
if len(members) > 0 {
m := make([]libregraph.User, 0, len(members))
for _, u := range members {
m = append(m, *i.createUserModelFromLDAP(u))
for _, ue := range members {
if u := i.createUserModelFromLDAP(ue); u != nil {
m = append(m, *u)
}
}
g.Members = m
}
Expand All @@ -706,14 +717,15 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra
return nil, err
}

result := []*libregraph.User{}

memberEntries, err := i.expandLDAPGroupMembers(ctx, e)
result := make([]*libregraph.User, 0, len(memberEntries))
if err != nil {
return nil, err
}
for _, member := range memberEntries {
result = append(result, i.createUserModelFromLDAP(member))
if u := i.createUserModelFromLDAP(member); u != nil {
result = append(result, u)
}
}

return result, nil
Expand Down Expand Up @@ -909,12 +921,20 @@ func (i *LDAP) createUserModelFromLDAP(e *ldap.Entry) *libregraph.User {
if e == nil {
return nil
}
return &libregraph.User{
DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.displayName)),
Mail: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.mail)),
OnPremisesSamAccountName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.userName)),
Id: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.id)),

opsan := e.GetEqualFoldAttributeValue(i.userAttributeMap.userName)
id := e.GetEqualFoldAttributeValue(i.userAttributeMap.id)

if id != "" && opsan != "" {
return &libregraph.User{
DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.displayName)),
Mail: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.mail)),
OnPremisesSamAccountName: &opsan,
Id: &id,
}
}
i.logger.Warn().Str("dn", e.DN).Msg("Invalid User. Missing username or id attribute")
return nil
}

func (i *LDAP) createGroupModelFromLDAP(e *ldap.Entry) *libregraph.Group {
Expand Down
43 changes: 42 additions & 1 deletion services/graph/pkg/identity/ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,20 @@ var userEntry = ldap.NewEntry("uid=user",
"mail": {"user@example"},
"entryuuid": {"abcd-defg"},
})
var invalidUserEntry = ldap.NewEntry("uid=user",
map[string][]string{
"uid": {"invalid"},
"displayname": {"DisplayName"},
"mail": {"user@example"},
})
var groupEntry = ldap.NewEntry("cn=group",
map[string][]string{
"cn": {"group"},
"entryuuid": {"abcd-defg"},
"member": {"uid=user,ou=people,dc=test"},
"member": {
"uid=user,ou=people,dc=test",
"uid=invalid,ou=people,dc=test",
},
})
var invalidGroupEntry = ldap.NewEntry("cn=invalid",
map[string][]string{
Expand Down Expand Up @@ -196,6 +205,21 @@ func TestGetUser(t *testing.T) {
} else if *u.Id != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id) {
t.Errorf("Expected GetUser to return a valid user")
}

// Mock invalid Search Result
lm = &mocks.Client{}
lm.On("Search", mock.Anything).
Return(
&ldap.SearchResult{
Entries: []*ldap.Entry{invalidUserEntry},
},
nil)

b, _ = getMockedBackend(lm, lconfig, &logger)
u, err = b.GetUser(context.Background(), "invalid", nil)
if err == nil || err.Error() != "itemNotFound" {
t.Errorf("Expected 'itemNotFound' got '%s'", err.Error())
}
}

func TestGetUsers(t *testing.T) {
Expand Down Expand Up @@ -298,9 +322,17 @@ func TestGetGroup(t *testing.T) {
Attributes: []string{"displayname", "entryUUID", "mail", "uid"},
Controls: []ldap.Control(nil),
}
sr3 := &ldap.SearchRequest{
BaseDN: "uid=invalid,ou=people,dc=test",
SizeLimit: 1,
Filter: "(objectclass=*)",
Attributes: []string{"displayname", "entryUUID", "mail", "uid"},
Controls: []ldap.Control(nil),
}

lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil)
lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil)
lm.On("Search", sr3).Return(&ldap.SearchResult{Entries: []*ldap.Entry{invalidUserEntry}}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err = b.GetGroup(context.Background(), "group", nil)
if err != nil {
Expand Down Expand Up @@ -385,9 +417,18 @@ func TestGetGroups(t *testing.T) {
Attributes: []string{"displayname", "entryUUID", "mail", "uid"},
Controls: []ldap.Control(nil),
}
sr3 := &ldap.SearchRequest{
BaseDN: "uid=invalid,ou=people,dc=test",
SizeLimit: 1,
Filter: "(objectclass=*)",
Attributes: []string{"displayname", "entryUUID", "mail", "uid"},
Controls: []ldap.Control(nil),
}

for _, param := range []url.Values{queryParamSelect, queryParamExpand} {
lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil)
lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil)
lm.On("Search", sr3).Return(&ldap.SearchResult{Entries: []*ldap.Entry{invalidUserEntry}}, nil)
b, _ = getMockedBackend(lm, lconfig, &logger)
g, err = b.GetGroups(context.Background(), param)
switch {
Expand Down

0 comments on commit 815d355

Please sign in to comment.