Skip to content

Commit

Permalink
formatting, linting, documenting - without functional changes
Browse files Browse the repository at this point in the history
cherry-pick 7e59a13
  • Loading branch information
Mario Hros authored and dkoshkin committed Feb 2, 2022
1 parent 67f8234 commit ddb3eab
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 83 deletions.
8 changes: 5 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ func main() {
config.Validate()

// Query the OIDC provider
config.SetOidcProvider()
if err := config.LoadOIDCProviderConfiguration(); err != nil {
log.Fatalln(err.Error())
}

authenticator := authentication.NewAuthenticator(config)
// Get clientset for Authorizers
Expand Down Expand Up @@ -77,7 +79,7 @@ func main() {
http.HandleFunc("/", server.RootHandler)

// Start
log.Debugf("Starting with options: %s", config)
log.Info("Listening on :4181")
log.Debugf("starting with options: %s", config)
log.Info("listening on :4181")
log.Info(http.ListenAndServe(":4181", nil))
}
62 changes: 38 additions & 24 deletions internal/authentication/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ func NewAuthenticator(config *configuration.Config) *Authenticator {

// Request Validation

// Cookie = hash(secret, cookie domain, email, expires)|expires|email|groups
// ValidateCookie validates the ID cookie in the request
// IDCookie = hash(secret, cookie domain, email, expires)|expires|email|group
func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
parts := strings.Split(c.Value, "|")

Expand Down Expand Up @@ -63,7 +64,8 @@ func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (string,
return parts[2], nil
}

// Validate email
// ValidateEmail validates that the provided email ends with one of the configured Domains or is part of the configured Whitelist.
// Also returns true if there is no Whitelist and no Domains configured.
func (a *Authenticator) ValidateEmail(email string) bool {
if len(a.config.Whitelist) > 0 || len(a.config.Domains) > 0 {
for _, whitelist := range a.config.Whitelist {
Expand All @@ -83,17 +85,20 @@ func (a *Authenticator) ValidateEmail(email string) bool {
return true
}

// Get oauth redirect uri
func (a *Authenticator) RedirectUri(r *http.Request) string {
// ComposeRedirectURI generates oauth redirect uri to return to from the OAuth2 provider
func (a *Authenticator) ComposeRedirectURI(r *http.Request) string {
if use, _ := a.useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, a.config.AuthHost, a.config.Path)
scheme := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", scheme, a.config.AuthHost, a.config.Path)

}

return fmt.Sprintf("%s%s", redirectBase(r), a.config.Path)
return fmt.Sprintf("%s%s", getRequestSchemeHost(r), a.config.Path)
}

// Should we use auth host + what it is
// useAuthDomain decides whether the host of the forwarded request
// matches the configured AuthHost and whether we can configure cookies for the AuthHost
// If it does, the function returns true and the top-level domain from the config we can use
func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) {
if a.config.AuthHost == "" {
return false, ""
Expand All @@ -111,7 +116,7 @@ func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) {

// Cookie methods

// Create an auth cookie
// MakeIDCookie creates an auth cookie
func (a *Authenticator) MakeIDCookie(r *http.Request, email string) *http.Cookie {
expires := a.cookieExpiry()
mac := a.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
Expand All @@ -128,7 +133,7 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string) *http.Cookie
}
}

// Create a name cookie
// MakeNameCookie creates a name cookie
func (a *Authenticator) MakeNameCookie(r *http.Request, name string) *http.Cookie {
expires := a.cookieExpiry()

Expand All @@ -143,7 +148,7 @@ func (a *Authenticator) MakeNameCookie(r *http.Request, name string) *http.Cooki
}
}

// Make a CSRF cookie (used during login only)
// MakeCSRFCookie creates a CSRF cookie (used during login only)
func (a *Authenticator) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: a.config.CSRFCookieName,
Expand All @@ -156,7 +161,7 @@ func (a *Authenticator) MakeCSRFCookie(r *http.Request, nonce string) *http.Cook
}
}

// Create a cookie to clear csrf cookie
// ClearCSRFCookie clears the csrf cookie
func (a *Authenticator) ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: a.config.CSRFCookieName,
Expand All @@ -169,7 +174,7 @@ func (a *Authenticator) ClearCSRFCookie(r *http.Request) *http.Cookie {
}
}

// Validate the csrf cookie against state
// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
state := r.URL.Query().Get("state")

Expand All @@ -190,15 +195,16 @@ func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
return true, state[33:], nil
}

func Nonce() (error, string) {
// GenerateNonce generates a random nonce string
func GenerateNonce() (string, error) {
// Make nonce
nonce := make([]byte, 16)
_, err := rand.Read(nonce)
if err != nil {
return err, ""
return "", err
}

return nil, fmt.Sprintf("%x", nonce)
return fmt.Sprintf("%x", nonce), nil
}

// Cookie domain
Expand All @@ -224,7 +230,10 @@ func (a *Authenticator) csrfCookieDomain(r *http.Request) string {
return p[0]
}

// Return matching cookie domain if exists
// matchCookieDomains checks if the provided domain maches any domain configured in the CookieDomains list
// and returns the domain from the list it matched with.
// The match is either the direct equality of domain names or the input subdomain (e.g. "a.test.com") belongs under a configured top domain ("test.com").
// If the domain does not match CookieDomains, false is returned with the input domain as the second return value.
func (a *Authenticator) matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")
Expand All @@ -240,7 +249,7 @@ func (a *Authenticator) matchCookieDomains(domain string) (bool, string) {
return false, p[0]
}

// Create cookie hmac
// cookieSignature creates a cookie hmac
func (a *Authenticator) cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, a.config.Secret)
hash.Write([]byte(a.GetCookieDomain(r)))
Expand All @@ -256,21 +265,26 @@ func (a *Authenticator) cookieExpiry() time.Time {

// Utility methods

// Get the redirect base
func redirectBase(r *http.Request) string {
// getRequestSchemeHost returns scheme://host part of the request
// Example output: "https://domain.com"
func getRequestSchemeHost(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")

return fmt.Sprintf("%s://%s", proto, host)
}

func GetUriPath(r *http.Request) string {
// GetRequestURI returns the full request URI with query parameters.
// The path includes the prefix (if stripPrefix middleware was used).
// Example output: "/prefix/path?query=1"
func GetRequestURI(r *http.Request) string {
prefix := r.Header.Get("X-Forwarded-Prefix")
uri := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s/%s", strings.TrimRight(prefix, "/"), strings.TrimLeft(uri, "/"))
}

// // Return url
func ReturnUrl(r *http.Request) string {
return fmt.Sprintf("%s%s", redirectBase(r), GetUriPath(r))
// GetRequestURL returns full requst URL scheme://host/uri with query params
// Example output: "https://domain.com/prefix/path?query=1"
func GetRequestURL(r *http.Request) string {
return fmt.Sprintf("%s%s", getRequestSchemeHost(r), GetRequestURI(r))
}
4 changes: 2 additions & 2 deletions internal/authentication/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ func TestAuthValidateCSRFCookie(t *testing.T) {

func TestAuthNonce(t *testing.T) {
assert := assert.New(t)
err, nonce1 := Nonce()
nonce1, err := GenerateNonce()
assert.Nil(err, "error generating nonce")
assert.Len(nonce1, 32, "length should be 32 chars")

err, nonce2 := Nonce()
nonce2, err := GenerateNonce()
assert.Nil(err, "error generating nonce")
assert.Len(nonce2, 32, "length should be 32 chars")

Expand Down
1 change: 1 addition & 0 deletions internal/authorization/authorizer.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package authorization

// Authorizer is the interface for implementing user authorization (check to see if the user can perform the action)
type Authorizer interface {
Authorize(user User, requestVerb, requestResource string) (bool, error)
}
58 changes: 40 additions & 18 deletions internal/authorization/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ const (
cacheSyncDuration = time.Minute * 10
)

type RBACAuthorizer struct {
// Authorizer implements the authorizer by watching and using ClusterRole and ClusterRoleBinding Kubernetes (RBAC) objects
type Authorizer struct {
clientset kubernetes.Interface
clusterRoleLister rbaclisterv1.ClusterRoleLister
clusterRoleBindingLister rbaclisterv1.ClusterRoleBindingLister
Expand All @@ -29,8 +30,9 @@ type RBACAuthorizer struct {
selector labels.Selector
}

func NewRBACAuthorizer(clientset kubernetes.Interface) *RBACAuthorizer {
authz := &RBACAuthorizer{
// NewAuthorizer creates a new RBAC authorizer
func NewAuthorizer(clientset kubernetes.Interface) *Authorizer {
authz := &Authorizer{
clientset: clientset,
syncDuration: cacheSyncDuration,
selector: labels.NewSelector(),
Expand All @@ -41,7 +43,9 @@ func NewRBACAuthorizer(clientset kubernetes.Interface) *RBACAuthorizer {
}

// Private
func (ra *RBACAuthorizer) getRoleByName(name string) *rbacv1.ClusterRole {

// getRoleByName finds the ClusterRole by its name or returns nil
func (ra *Authorizer) getRoleByName(name string) *rbacv1.ClusterRole {
clusterRole, err := ra.clusterRoleLister.Get(name)
if err != nil {
if errors.IsNotFound(err) {
Expand All @@ -59,25 +63,32 @@ func (ra *RBACAuthorizer) getRoleByName(name string) *rbacv1.ClusterRole {
return clusterRole
}

func (ra *RBACAuthorizer) getRoleFromGroups(target, role string, groups []string) *rbacv1.ClusterRole {
for _, group := range groups {
if group == target {
return ra.getRoleByName(role)
// getRoleFromGroups returns role specified in roleNameRef only if subjectGroupName is in the userGroups list
func (ra *Authorizer) getRoleFromGroups(roleNameRef, subjectGroupName string, userGroups []string) *rbacv1.ClusterRole {
// for every user group...
for _, group := range userGroups {
// if the group matches the group name in the subject, return the role
if group == subjectGroupName {
return ra.getRoleByName(roleNameRef)
}
}

// no user group match this subjectGroupName
return nil
}

func (ra *RBACAuthorizer) getRoleForSubject(user authorization.User, subject rbacv1.Subject, role string) *rbacv1.ClusterRole {
// getRoleForSubject gets the role bound to the subject depending on the subject kind (user or group).
// Returns nil if there is no rule matching or an unknown subject Kind is provided
func (ra *Authorizer) getRoleForSubject(user authorization.User, subject rbacv1.Subject, roleNameRef string) *rbacv1.ClusterRole {
if subject.Kind == "User" && subject.Name == user.GetName() {
return ra.getRoleByName(role)
return ra.getRoleByName(roleNameRef)
} else if subject.Kind == "Group" {
return ra.getRoleFromGroups(subject.Name, role, user.GetGroups())
return ra.getRoleFromGroups(roleNameRef, subject.Name, user.GetGroups())
}
return nil
}

func (ra *RBACAuthorizer) prepareCache() {
func (ra *Authorizer) prepareCache() {
ra.sharedInformerFactory = informers.NewSharedInformerFactory(ra.clientset, ra.syncDuration)
ra.clusterRoleLister = ra.sharedInformerFactory.Rbac().V1().ClusterRoles().Lister()
ra.clusterRoleBindingLister = ra.sharedInformerFactory.Rbac().V1().ClusterRoleBindings().Lister()
Expand All @@ -86,7 +97,9 @@ func (ra *RBACAuthorizer) prepareCache() {
}

// Public
func (ra *RBACAuthorizer) GetRoles(user authorization.User) (*rbacv1.ClusterRoleList, error) {

// GetRolesBoundToUser returns list of roles bound to the specified user or groups the user is part of
func (ra *Authorizer) GetRolesBoundToUser(user authorization.User) (*rbacv1.ClusterRoleList, error) {
clusterRoles := rbacv1.ClusterRoleList{}
clusterRoleBindings, err := ra.clusterRoleBindingLister.List(ra.selector)
if err != nil {
Expand All @@ -105,29 +118,37 @@ func (ra *RBACAuthorizer) GetRoles(user authorization.User) (*rbacv1.ClusterRole
}

// Interface methods
func (ra *RBACAuthorizer) Authorize(user authorization.User, requestVerb, requestResource string) (bool, error) {
roles, err := ra.GetRoles(user)

// Authorize performs the authorization logic
func (ra *Authorizer) Authorize(user authorization.User, requestVerb, requestResource string) (bool, error) {
roles, err := ra.GetRolesBoundToUser(user)
if err != nil {
return false, err
}

// deny if no roles defined
if len(roles.Items) < 1 {
return false, nil
}

// check all rules in the list of roles to see if any matches
for _, role := range roles.Items {
for _, rule := range role.Rules {
if VerbMatches(&rule, requestVerb) && NonResourceURLMatches(&rule, requestResource) {
if verbMatches(&rule, requestVerb) && nonResourceURLMatches(&rule, requestResource) {
return true, nil
}
}
}

// no rules match the request -> deny
return false, nil
}

// Utility
func VerbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool {

// verbMatches returns true if the requested verb matches a verb specifid in the rule
// Also matches if the rule mentiones special "all verbs" rule *
func verbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool {
for _, ruleVerb := range rule.Verbs {
if ruleVerb == rbacv1.VerbAll {
return true
Expand All @@ -140,7 +161,8 @@ func VerbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool {
return false
}

func NonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL string) bool {
// nonResourceURLMatches returns true if the requested URL matches a policy the rule
func nonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL string) bool {
for _, ruleURL := range rule.NonResourceURLs {
if ruleURL == rbacv1.NonResourceAll {
return true
Expand Down
8 changes: 4 additions & 4 deletions internal/authorization/rbac/rbac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ type testCase struct {
should bool
}

func getRBACAuthorizer(objs ...runtime.Object) *RBACAuthorizer {
return NewRBACAuthorizer(fake.NewSimpleClientset(objs...))
func getRBACAuthorizer(objs ...runtime.Object) *Authorizer {
return NewAuthorizer(fake.NewSimpleClientset(objs...))
}

func makeRole(name string, verbs, urls []string) rbacv1.ClusterRole {
Expand Down Expand Up @@ -98,7 +98,7 @@ func TestRBACAuthorizer_GetRoles(t *testing.T) {
a := getRBACAuthorizer(roles, bindings)

u1 := authorization.User{Name: "u1"}
r, err := a.GetRoles(u1)
r, err := a.GetRolesBoundToUser(u1)

assert.NilError(t, err)
assert.Equal(t, len(r.Items), 2)
Expand All @@ -107,7 +107,7 @@ func TestRBACAuthorizer_GetRoles(t *testing.T) {

u2 := authorization.User{Name: "u2", Groups: []string{"g1", "g2"}}

r, err = a.GetRoles(u2)
r, err = a.GetRolesBoundToUser(u2)
assert.NilError(t, err)
assert.Equal(t, len(r.Items), 1)
assert.Equal(t, r.Items[0].Name, "r3")
Expand Down
3 changes: 3 additions & 0 deletions internal/authorization/user.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package authorization

// User represents an autorized user
type User struct {
Name string
Groups []string
}

// GetName returns the user name
func (k *User) GetName() string {
return k.Name
}

// GetGroups return list of groups the user belongs to
func (k *User) GetGroups() []string {
return k.Groups
}
1 change: 1 addition & 0 deletions internal/authorization/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"strings"
)

// PathMatches returns true if the URL matches the pattern containing an optional wildcard '*' character
func PathMatches(url, pattern string) bool {
return pattern == url ||
(strings.HasSuffix(pattern, "*") && strings.HasPrefix(url, strings.TrimRight(pattern, "*")))
Expand Down
Loading

0 comments on commit ddb3eab

Please sign in to comment.