Skip to content

Commit

Permalink
Handle requested claims in user info
Browse files Browse the repository at this point in the history
  • Loading branch information
giftkugel committed Oct 6, 2024
1 parent 765f7ef commit ad5eb5c
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 45 deletions.
14 changes: 8 additions & 6 deletions internal/manager/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ type clientStores struct {
}

type ValidAccessToken struct {
User *config.User
Client *config.Client
Scopes []string
User *config.User
Client *config.Client
Scopes []string
RequestedClaims *oidc.ClaimsParameter
}

type Manager struct {
Expand Down Expand Up @@ -286,9 +287,10 @@ func (tokenManager *Manager) validateAccessToken(accessTokenValue string) (*Vali
}

validAccessToken := &ValidAccessToken{
User: user,
Client: client,
Scopes: accessToken.Scopes,
User: user,
Client: client,
Scopes: accessToken.Scopes,
RequestedClaims: accessToken.RequestedClaims,
}

return validAccessToken, true
Expand Down
41 changes: 37 additions & 4 deletions internal/oidc/claims.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
package oidc

const (
ClaimNonce string = "nonce"
ClaimAuthorizedParty string = "azp"
ClaimAtHash string = "at_hash"
ClaimAuthTime string = "auth_time"
ClaimNonce string = "nonce"
ClaimAuthorizedParty string = "azp"
ClaimAtHash string = "at_hash"
ClaimAuthTime string = "auth_time"
ClaimName string = "name"
ClaimGivenName string = "given_name"
ClaimMiddleName string = "middle_name"
ClaimFamilyName string = "family_name"
ClaimNickname string = "nickname"
ClaimPreferredUsername string = "preferred_username"
ClaimGender string = "gender"
ClaimBirthdate string = "birthdate"
ClaimZoneInfo string = "zoneinfo"
ClaimLocale string = "locale"
ClaimWebsite string = "website"
ClaimProfile string = "profile"
ClaimPicture string = "picture"
ClaimEmail string = "email"
ClaimEmailVerified string = "email_verified"
ClaimPhoneNumber string = "phone_number"
ClaimPhoneNumberVerified string = "phone_number_verified"
ClaimUpdatedAt string = "updated_at"
ClaimAddressFormatted string = "formatted"
ClaimAddressStreetAddress string = "street_address"
ClaimAddressLocality string = "locality"
ClaimAddressPostalCode string = "postal_code"
ClaimAddressRegion string = "region"
ClaimAddressCountry string = "country"
)

// ClaimsParameterMember as described in https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests
Expand All @@ -19,3 +43,12 @@ type ClaimsParameter struct {
UserInfo map[string]*ClaimsParameterMember `json:"userinfo,omitempty"`
IdToken map[string]*ClaimsParameterMember `json:"id_token,omitempty"`
}

func HasUserInfoClaim(cp *ClaimsParameter, name string) bool {
if cp != nil && cp.UserInfo != nil {
_, exists := cp.UserInfo[name]
return exists
} else {
return false
}
}
1 change: 1 addition & 0 deletions internal/server/handler/authorize/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ func createAuthSession(id string, authorizeRequest *authorizeRequestValues, r *h
ResponseTypes: responseTypes,
Scopes: authorizeRequest.requestedScopes,
State: authorizeRequest.stateParameter,
RequestedClaims: authorizeRequest.requestedClaims,
}
}

Expand Down
150 changes: 115 additions & 35 deletions internal/server/handler/oidc/userinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,43 +42,13 @@ func (h *UserInfoHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
user := validAccessToken.User
client := validAccessToken.Client
scopes := validAccessToken.Scopes
requestedClaims := validAccessToken.RequestedClaims
response.Subject = user.Username
if slices.Contains(scopes, oidc.ScopeProfile) {
response.Name = user.GetName()
response.GivenName = user.UserProfile.GivenName
response.MiddleName = user.UserProfile.MiddleName
response.FamilyName = user.UserProfile.FamilyName
response.Nickname = user.UserProfile.Nickname
response.PreferredUserName = user.GetPreferredUsername()
response.Gender = user.UserProfile.Gender
response.BirthDate = user.UserProfile.BirthDate
response.ZoneInfo = user.UserProfile.ZoneInfo
response.Locale = user.UserProfile.Locale
response.Website = user.UserProfile.Website
response.Profile = user.UserProfile.Profile
response.Picture = user.UserProfile.Picture
response.UpdatedAt = system.GetStartTime().Unix()
}

if slices.Contains(scopes, oidc.ScopeAddress) && user.UserInformation.Address != nil {
response.Address = &config.UserAddress{}
response.Address.Formatted = user.GetFormattedAddress()
response.Address.Street = user.UserInformation.Address.Street
response.Address.City = user.UserInformation.Address.City
response.Address.PostalCode = user.UserInformation.Address.PostalCode
response.Address.Region = user.UserInformation.Address.Region
response.Address.Country = user.UserInformation.Address.Country
}

if slices.Contains(scopes, oidc.ScopeEmail) {
response.Email = user.UserInformation.Email
response.EmailVerified = user.UserInformation.EmailVerified
}

if slices.Contains(scopes, oidc.ScopePhone) {
response.PhoneNumber = user.UserInformation.PhoneNumber
response.PhoneVerified = user.UserInformation.PhoneVerified
}
applyProfileClaims(user, scopes, requestedClaims, response)
applyAddressClaims(user, scopes, requestedClaims, response)
applyEmailClaims(user, scopes, requestedClaims, response)
applyPhoneClaims(user, scopes, requestedClaims, response)

var result interface{}
result = response
Expand Down Expand Up @@ -110,6 +80,110 @@ func (h *UserInfoHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

func applyPhoneClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) {
applyClaim(scopes, oidc.ScopePhone, requestedClaims, oidc.ClaimPhoneNumber, func() {
response.PhoneNumber = user.UserInformation.PhoneNumber
})
applyClaim(scopes, oidc.ScopePhone, requestedClaims, oidc.ClaimPhoneNumberVerified, func() {
response.PhoneVerified = user.UserInformation.PhoneVerified
})
}

func applyEmailClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) {
applyClaim(scopes, oidc.ScopeEmail, requestedClaims, oidc.ClaimEmail, func() {
response.Email = user.UserInformation.Email
})
applyClaim(scopes, oidc.ScopeEmail, requestedClaims, oidc.ClaimEmailVerified, func() {
response.EmailVerified = user.UserInformation.EmailVerified
})
}

func applyAddressClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) {
if user.UserInformation.Address != nil {
applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressFormatted, func() {
if response.Address == nil {
response.Address = &config.UserAddress{}
}
response.Address.Formatted = user.GetFormattedAddress()
})
applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressStreetAddress, func() {
if response.Address == nil {
response.Address = &config.UserAddress{}
}
response.Address.Street = user.UserInformation.Address.Street
})
applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressLocality, func() {
if response.Address == nil {
response.Address = &config.UserAddress{}
}
response.Address.City = user.UserInformation.Address.City
})
applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressPostalCode, func() {
if response.Address == nil {
response.Address = &config.UserAddress{}
}
response.Address.PostalCode = user.UserInformation.Address.PostalCode
})
applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressRegion, func() {
if response.Address == nil {
response.Address = &config.UserAddress{}
}
response.Address.Region = user.UserInformation.Address.Region
})
applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressCountry, func() {
if response.Address == nil {
response.Address = &config.UserAddress{}
}
response.Address.Country = user.UserInformation.Address.Country
})
}
}

func applyProfileClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) {
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimName, func() {
response.Name = user.GetName()
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimGivenName, func() {
response.GivenName = user.UserProfile.GivenName
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimMiddleName, func() {
response.MiddleName = user.UserProfile.MiddleName
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimFamilyName, func() {
response.FamilyName = user.UserProfile.FamilyName
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimNickname, func() {
response.Nickname = user.UserProfile.Nickname
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimPreferredUsername, func() {
response.PreferredUserName = user.GetPreferredUsername()
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimGender, func() {
response.Gender = user.UserProfile.Gender
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimBirthdate, func() {
response.BirthDate = user.UserProfile.BirthDate
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimZoneInfo, func() {
response.ZoneInfo = user.UserProfile.ZoneInfo
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimLocale, func() {
response.Locale = user.UserProfile.Locale
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimWebsite, func() {
response.Website = user.UserProfile.Website
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimProfile, func() {
response.Profile = user.UserProfile.Profile
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimPicture, func() {
response.Picture = user.UserProfile.Picture
})
applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimUpdatedAt, func() {
response.UpdatedAt = system.GetStartTime().Unix()
})
}

func updateResponse(response any, key string, value any) (map[string]any, error) {
marshaledResponse, marshalError := json.Marshal(response)
if marshalError != nil {
Expand All @@ -126,3 +200,9 @@ func updateResponse(response any, key string, value any) (map[string]any, error)

return updatedResponse, nil
}

func applyClaim(scopes []string, scope string, requestedClaims *oidc.ClaimsParameter, name string, consumer func()) {
if slices.Contains(scopes, scope) || oidc.HasUserInfoClaim(requestedClaims, name) {
consumer()
}
}

0 comments on commit ad5eb5c

Please sign in to comment.