diff --git a/internal/manager/token/token.go b/internal/manager/token/token.go index 56c7cd9..eff90b8 100644 --- a/internal/manager/token/token.go +++ b/internal/manager/token/token.go @@ -31,9 +31,10 @@ type clientStores struct { } type ValidAccessToken struct { - User *config.User - Client *config.Client - Scopes []string + User *config.User + Client *config.Client + Scopes []string + RequestedClaims *oidc.ClaimsParameter } type Manager struct { @@ -286,9 +287,10 @@ func (tokenManager *Manager) validateAccessToken(accessTokenValue string) (*Vali } validAccessToken := &ValidAccessToken{ - User: user, - Client: client, - Scopes: accessToken.Scopes, + User: user, + Client: client, + Scopes: accessToken.Scopes, + RequestedClaims: accessToken.RequestedClaims, } return validAccessToken, true diff --git a/internal/oidc/claims.go b/internal/oidc/claims.go index 5168c8f..d6d5e01 100644 --- a/internal/oidc/claims.go +++ b/internal/oidc/claims.go @@ -1,10 +1,34 @@ package oidc const ( - ClaimNonce string = "nonce" - ClaimAuthorizedParty string = "azp" - ClaimAtHash string = "at_hash" - ClaimAuthTime string = "auth_time" + ClaimNonce string = "nonce" + ClaimAuthorizedParty string = "azp" + ClaimAtHash string = "at_hash" + ClaimAuthTime string = "auth_time" + ClaimName string = "name" + ClaimGivenName string = "given_name" + ClaimMiddleName string = "middle_name" + ClaimFamilyName string = "family_name" + ClaimNickname string = "nickname" + ClaimPreferredUsername string = "preferred_username" + ClaimGender string = "gender" + ClaimBirthdate string = "birthdate" + ClaimZoneInfo string = "zoneinfo" + ClaimLocale string = "locale" + ClaimWebsite string = "website" + ClaimProfile string = "profile" + ClaimPicture string = "picture" + ClaimEmail string = "email" + ClaimEmailVerified string = "email_verified" + ClaimPhoneNumber string = "phone_number" + ClaimPhoneNumberVerified string = "phone_number_verified" + ClaimUpdatedAt string = "updated_at" + ClaimAddressFormatted string = "formatted" + ClaimAddressStreetAddress string = "street_address" + ClaimAddressLocality string = "locality" + ClaimAddressPostalCode string = "postal_code" + ClaimAddressRegion string = "region" + ClaimAddressCountry string = "country" ) // ClaimsParameterMember as described in https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests @@ -19,3 +43,12 @@ type ClaimsParameter struct { UserInfo map[string]*ClaimsParameterMember `json:"userinfo,omitempty"` IdToken map[string]*ClaimsParameterMember `json:"id_token,omitempty"` } + +func HasUserInfoClaim(cp *ClaimsParameter, name string) bool { + if cp != nil && cp.UserInfo != nil { + _, exists := cp.UserInfo[name] + return exists + } else { + return false + } +} diff --git a/internal/server/handler/authorize/authorize.go b/internal/server/handler/authorize/authorize.go index 464b13e..2b5a3bb 100644 --- a/internal/server/handler/authorize/authorize.go +++ b/internal/server/handler/authorize/authorize.go @@ -637,6 +637,7 @@ func createAuthSession(id string, authorizeRequest *authorizeRequestValues, r *h ResponseTypes: responseTypes, Scopes: authorizeRequest.requestedScopes, State: authorizeRequest.stateParameter, + RequestedClaims: authorizeRequest.requestedClaims, } } diff --git a/internal/server/handler/oidc/userinfo.go b/internal/server/handler/oidc/userinfo.go index d4d3d9f..9ee4863 100644 --- a/internal/server/handler/oidc/userinfo.go +++ b/internal/server/handler/oidc/userinfo.go @@ -42,43 +42,13 @@ func (h *UserInfoHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { user := validAccessToken.User client := validAccessToken.Client scopes := validAccessToken.Scopes + requestedClaims := validAccessToken.RequestedClaims response.Subject = user.Username - if slices.Contains(scopes, oidc.ScopeProfile) { - response.Name = user.GetName() - response.GivenName = user.UserProfile.GivenName - response.MiddleName = user.UserProfile.MiddleName - response.FamilyName = user.UserProfile.FamilyName - response.Nickname = user.UserProfile.Nickname - response.PreferredUserName = user.GetPreferredUsername() - response.Gender = user.UserProfile.Gender - response.BirthDate = user.UserProfile.BirthDate - response.ZoneInfo = user.UserProfile.ZoneInfo - response.Locale = user.UserProfile.Locale - response.Website = user.UserProfile.Website - response.Profile = user.UserProfile.Profile - response.Picture = user.UserProfile.Picture - response.UpdatedAt = system.GetStartTime().Unix() - } - - if slices.Contains(scopes, oidc.ScopeAddress) && user.UserInformation.Address != nil { - response.Address = &config.UserAddress{} - response.Address.Formatted = user.GetFormattedAddress() - response.Address.Street = user.UserInformation.Address.Street - response.Address.City = user.UserInformation.Address.City - response.Address.PostalCode = user.UserInformation.Address.PostalCode - response.Address.Region = user.UserInformation.Address.Region - response.Address.Country = user.UserInformation.Address.Country - } - if slices.Contains(scopes, oidc.ScopeEmail) { - response.Email = user.UserInformation.Email - response.EmailVerified = user.UserInformation.EmailVerified - } - - if slices.Contains(scopes, oidc.ScopePhone) { - response.PhoneNumber = user.UserInformation.PhoneNumber - response.PhoneVerified = user.UserInformation.PhoneVerified - } + applyProfileClaims(user, scopes, requestedClaims, response) + applyAddressClaims(user, scopes, requestedClaims, response) + applyEmailClaims(user, scopes, requestedClaims, response) + applyPhoneClaims(user, scopes, requestedClaims, response) var result interface{} result = response @@ -110,6 +80,110 @@ func (h *UserInfoHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func applyPhoneClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) { + applyClaim(scopes, oidc.ScopePhone, requestedClaims, oidc.ClaimPhoneNumber, func() { + response.PhoneNumber = user.UserInformation.PhoneNumber + }) + applyClaim(scopes, oidc.ScopePhone, requestedClaims, oidc.ClaimPhoneNumberVerified, func() { + response.PhoneVerified = user.UserInformation.PhoneVerified + }) +} + +func applyEmailClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) { + applyClaim(scopes, oidc.ScopeEmail, requestedClaims, oidc.ClaimEmail, func() { + response.Email = user.UserInformation.Email + }) + applyClaim(scopes, oidc.ScopeEmail, requestedClaims, oidc.ClaimEmailVerified, func() { + response.EmailVerified = user.UserInformation.EmailVerified + }) +} + +func applyAddressClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) { + if user.UserInformation.Address != nil { + applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressFormatted, func() { + if response.Address == nil { + response.Address = &config.UserAddress{} + } + response.Address.Formatted = user.GetFormattedAddress() + }) + applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressStreetAddress, func() { + if response.Address == nil { + response.Address = &config.UserAddress{} + } + response.Address.Street = user.UserInformation.Address.Street + }) + applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressLocality, func() { + if response.Address == nil { + response.Address = &config.UserAddress{} + } + response.Address.City = user.UserInformation.Address.City + }) + applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressPostalCode, func() { + if response.Address == nil { + response.Address = &config.UserAddress{} + } + response.Address.PostalCode = user.UserInformation.Address.PostalCode + }) + applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressRegion, func() { + if response.Address == nil { + response.Address = &config.UserAddress{} + } + response.Address.Region = user.UserInformation.Address.Region + }) + applyClaim(scopes, oidc.ScopeAddress, requestedClaims, oidc.ClaimAddressCountry, func() { + if response.Address == nil { + response.Address = &config.UserAddress{} + } + response.Address.Country = user.UserInformation.Address.Country + }) + } +} + +func applyProfileClaims(user *config.User, scopes []string, requestedClaims *oidc.ClaimsParameter, response *UserInfoResponse) { + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimName, func() { + response.Name = user.GetName() + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimGivenName, func() { + response.GivenName = user.UserProfile.GivenName + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimMiddleName, func() { + response.MiddleName = user.UserProfile.MiddleName + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimFamilyName, func() { + response.FamilyName = user.UserProfile.FamilyName + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimNickname, func() { + response.Nickname = user.UserProfile.Nickname + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimPreferredUsername, func() { + response.PreferredUserName = user.GetPreferredUsername() + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimGender, func() { + response.Gender = user.UserProfile.Gender + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimBirthdate, func() { + response.BirthDate = user.UserProfile.BirthDate + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimZoneInfo, func() { + response.ZoneInfo = user.UserProfile.ZoneInfo + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimLocale, func() { + response.Locale = user.UserProfile.Locale + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimWebsite, func() { + response.Website = user.UserProfile.Website + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimProfile, func() { + response.Profile = user.UserProfile.Profile + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimPicture, func() { + response.Picture = user.UserProfile.Picture + }) + applyClaim(scopes, oidc.ScopeProfile, requestedClaims, oidc.ClaimUpdatedAt, func() { + response.UpdatedAt = system.GetStartTime().Unix() + }) +} + func updateResponse(response any, key string, value any) (map[string]any, error) { marshaledResponse, marshalError := json.Marshal(response) if marshalError != nil { @@ -126,3 +200,9 @@ func updateResponse(response any, key string, value any) (map[string]any, error) return updatedResponse, nil } + +func applyClaim(scopes []string, scope string, requestedClaims *oidc.ClaimsParameter, name string, consumer func()) { + if slices.Contains(scopes, scope) || oidc.HasUserInfoClaim(requestedClaims, name) { + consumer() + } +}