Skip to content

Commit

Permalink
refactor without global config
Browse files Browse the repository at this point in the history
  • Loading branch information
jr0d committed Apr 7, 2021
1 parent e0748ab commit e7aabee
Show file tree
Hide file tree
Showing 13 changed files with 410 additions and 332 deletions.
12 changes: 8 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package main

import (
"github.com/mesosphere/traefik-forward-auth/internal/api/storage/v1alpha1"
"github.com/mesosphere/traefik-forward-auth/internal/authentication"
"github.com/mesosphere/traefik-forward-auth/internal/configuration"
"github.com/mesosphere/traefik-forward-auth/internal/handlers"
kubernetes "github.com/mesosphere/traefik-forward-auth/internal/kubernetes"
"github.com/mesosphere/traefik-forward-auth/internal/storage"
"github.com/mesosphere/traefik-forward-auth/internal/storage/cluster"
Expand All @@ -10,15 +13,14 @@ import (
"time"

"github.com/gorilla/sessions"
internal "github.com/mesosphere/traefik-forward-auth/internal"
logger "github.com/mesosphere/traefik-forward-auth/internal/log"
k8s "k8s.io/client-go/kubernetes"
)

// Main
func main() {
// Parse options
config := internal.NewGlobalConfig(os.Args[1:])
config := configuration.NewGlobalConfig(os.Args[1:])

// Setup logger
log := logger.NewDefaultLogger(config.LogLevel, config.LogFormat)
Expand All @@ -29,6 +31,7 @@ func main() {
// Query the OIDC provider
config.SetOidcProvider()

authenticator := authentication.NewAuthenticator(config)
// Get clientset for Authorizers
var clientset k8s.Interface
if config.EnableRBAC || config.EnableInClusterStorage {
Expand Down Expand Up @@ -57,7 +60,8 @@ func main() {
config.ClusterStoreNamespace,
string(config.Secret),
config.Lifetime,
time.Duration(config.ClusterStoreCacheTTL)*time.Second)
time.Duration(config.ClusterStoreCacheTTL)*time.Second,
authenticator)

gc := cluster.NewGC(clusterStorage, time.Minute, false, true)

Expand All @@ -66,7 +70,7 @@ func main() {
}
}
// Build server
server := internal.NewServer(userInfoStore, clientset)
server := handlers.NewServer(userInfoStore, clientset, config)

// Attach router to default server
http.HandleFunc("/", server.RootHandler)
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imdario/mergo v0.3.5 h1:JboBksRwiiAJWvIYJVo46AfV+IAIKZpfrSzVKj42R4Q=
github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
Expand Down
192 changes: 69 additions & 123 deletions internal/auth.go → internal/authentication/auth.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tfa
package authentication

import (
"crypto/hmac"
Expand All @@ -11,12 +11,22 @@ import (
"strconv"
"strings"
"time"

"github.com/mesosphere/traefik-forward-auth/internal/configuration"
)

type Authenticator struct {
config *configuration.Config
}

func NewAuthenticator(config *configuration.Config) *Authenticator {
return &Authenticator{config}
}

// Request Validation

// Cookie = hash(secret, cookie domain, email, expires)|expires|email|groups
func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
parts := strings.Split(c.Value, "|")

if len(parts) != 3 {
Expand All @@ -28,7 +38,7 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
return "", errors.New("unable to decode cookie mac")
}

expectedSignature := cookieSignature(r, parts[2], parts[1])
expectedSignature := a.cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil {
return "", errors.New("unable to generate mac")
Expand All @@ -54,16 +64,16 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
}

// Validate email
func ValidateEmail(email string) bool {
if len(config.Whitelist) > 0 || len(config.Domains) > 0 {
for _, whitelist := range config.Whitelist {
func (a *Authenticator) ValidateEmail(email string) bool {
if len(a.config.Whitelist) > 0 || len(a.config.Domains) > 0 {
for _, whitelist := range a.config.Whitelist {
if email == whitelist {
return true
}
}

parts := strings.Split(email, "@")
for _, domain := range config.Domains {
for _, domain := range a.config.Domains {
if len(parts) >= 2 && domain == parts[1] {
return true
}
Expand All @@ -73,48 +83,27 @@ func ValidateEmail(email string) bool {
return true
}

// Utility methods

// Get the redirect base
func redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")

return fmt.Sprintf("%s://%s", proto, host)
}

func GetUriPath(r *http.Request) string {
prefix := r.Header.Get("X-Forwarded-Prefix")
uri := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s/%s", strings.TrimRight(prefix, "/"), strings.TrimLeft(uri, "/"))
}

// // Return url
func returnUrl(r *http.Request) string {
return fmt.Sprintf("%s%s", redirectBase(r), GetUriPath(r))
}

// Get oauth redirect uri
func redirectUri(r *http.Request) string {
if use, _ := useAuthDomain(r); use {
func (a *Authenticator) RedirectUri(r *http.Request) string {
if use, _ := a.useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
return fmt.Sprintf("%s://%s%s", proto, a.config.AuthHost, a.config.Path)
}

return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
return fmt.Sprintf("%s%s", redirectBase(r), a.config.Path)
}

// Should we use auth host + what it is
func useAuthDomain(r *http.Request) (bool, string) {
if config.AuthHost == "" {
func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) {
if a.config.AuthHost == "" {
return false, ""
}

// Does the request match a given cookie domain?
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
reqMatch, reqHost := a.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))

// Do any of the auth hosts match a cookie domain?
authMatch, authHost := matchCookieDomains(config.AuthHost)
authMatch, authHost := a.matchCookieDomains(a.config.AuthHost)

// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
Expand All @@ -123,59 +112,59 @@ func useAuthDomain(r *http.Request) (bool, string) {
// Cookie methods

// Create an auth cookie
func MakeIDCookie(r *http.Request, email string) *http.Cookie {
expires := cookieExpiry()
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
func (a *Authenticator) MakeIDCookie(r *http.Request, email string) *http.Cookie {
expires := a.cookieExpiry()
mac := a.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)

return &http.Cookie{
Name: config.CookieName,
Name: a.config.CookieName,
Value: value,
Path: "/",
Domain: GetCookieDomain(r),
Domain: a.GetCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Secure: !a.config.InsecureCookie,
Expires: expires,
}
}

// Create a name cookie
func MakeNameCookie(r *http.Request, name string) *http.Cookie {
expires := cookieExpiry()
func (a *Authenticator) MakeNameCookie(r *http.Request, name string) *http.Cookie {
expires := a.cookieExpiry()

return &http.Cookie{
Name: config.UserCookieName,
Name: a.config.UserCookieName,
Value: name,
Path: "/",
Domain: GetCookieDomain(r),
Domain: a.GetCookieDomain(r),
HttpOnly: false,
Secure: false,
Expires: expires,
}
}

// Make a CSRF cookie (used during login only)
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
func (a *Authenticator) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: a.config.CSRFCookieName,
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
Domain: a.csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
Secure: !a.config.InsecureCookie,
Expires: a.cookieExpiry(),
}
}

// Create a cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
func (a *Authenticator) ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: a.config.CSRFCookieName,
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
Domain: a.csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Secure: !a.config.InsecureCookie,
Expires: time.Now().Local().Add(time.Hour * -1),
}
}
Expand Down Expand Up @@ -213,18 +202,18 @@ func Nonce() (error, string) {
}

// Cookie domain
func GetCookieDomain(r *http.Request) string {
func (a *Authenticator) GetCookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")

// Check if any of the given cookie domains matches
_, domain := matchCookieDomains(host)
_, domain := a.matchCookieDomains(host)
return domain
}

// Cookie domain
func csrfCookieDomain(r *http.Request) string {
func (a *Authenticator) csrfCookieDomain(r *http.Request) string {
var host string
if use, domain := useAuthDomain(r); use {
if use, domain := a.useAuthDomain(r); use {
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
Expand All @@ -236,12 +225,12 @@ func csrfCookieDomain(r *http.Request) string {
}

// Return matching cookie domain if exists
func matchCookieDomains(domain string) (bool, string) {
func (a *Authenticator) matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")

if config != nil {
for _, d := range config.CookieDomains {
if a.config != nil {
for _, d := range a.config.CookieDomains {
if d.Match(p[0]) {
return true, d.Domain
}
Expand All @@ -251,79 +240,36 @@ func matchCookieDomains(domain string) (bool, string) {
}

// Create cookie hmac
func cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, config.Secret)
hash.Write([]byte(GetCookieDomain(r)))
func (a *Authenticator) cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, a.config.Secret)
hash.Write([]byte(a.GetCookieDomain(r)))
hash.Write([]byte(email))
hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
}

// Get cookie expirary
func cookieExpiry() time.Time {
return time.Now().Local().Add(config.Lifetime)
}

// Cookie Domain

// Cookie Domain
type CookieDomain struct {
Domain string `description:"TEST1"`
DomainLen int `description:"TEST2"`
SubDomain string `description:"TEST3"`
SubDomainLen int `description:"TEST4"`
}

func NewCookieDomain(domain string) *CookieDomain {
return &CookieDomain{
Domain: domain,
DomainLen: len(domain),
SubDomain: fmt.Sprintf(".%s", domain),
SubDomainLen: len(domain) + 1,
}
func (a *Authenticator) cookieExpiry() time.Time {
return time.Now().Local().Add(a.config.Lifetime)
}

func (c *CookieDomain) Match(host string) bool {
// Exact domain match?
if host == c.Domain {
return true
}

// Subdomain match?
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
return true
}

return false
}
// Utility methods

func (c *CookieDomain) UnmarshalFlag(value string) error {
*c = *NewCookieDomain(value)
return nil
}
// Get the redirect base
func redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")

func (c *CookieDomain) MarshalFlag() (string, error) {
return c.Domain, nil
return fmt.Sprintf("%s://%s", proto, host)
}

// Legacy support for comma separated list of cookie domains

type CookieDomains []CookieDomain

func (c *CookieDomains) UnmarshalFlag(value string) error {
if len(value) > 0 {
for _, d := range strings.Split(value, ",") {
cookieDomain := NewCookieDomain(d)
*c = append(*c, *cookieDomain)
}
}
return nil
func GetUriPath(r *http.Request) string {
prefix := r.Header.Get("X-Forwarded-Prefix")
uri := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s/%s", strings.TrimRight(prefix, "/"), strings.TrimLeft(uri, "/"))
}

func (c *CookieDomains) MarshalFlag() (string, error) {
var domains []string
for _, d := range *c {
domains = append(domains, d.Domain)
}
return strings.Join(domains, ","), nil
// // Return url
func ReturnUrl(r *http.Request) string {
return fmt.Sprintf("%s%s", redirectBase(r), GetUriPath(r))
}
Loading

0 comments on commit e7aabee

Please sign in to comment.