Skip to content

Commit

Permalink
Use typical error patterns
Browse files Browse the repository at this point in the history
`makeSessionCookie` turns its err value into a nil, losing its info.

Also make fatal problems use Error and non-fatal problems use Warn.

cherry-pick 3e4e6f9
  • Loading branch information
jongiddy authored and dkoshkin committed Feb 3, 2022
1 parent 0b2f10c commit 08f8b4d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 22 deletions.
8 changes: 5 additions & 3 deletions internal/authentication/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) {
// Cookie methods

// MakeIDCookie creates an auth cookie
func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string) *http.Cookie {
func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string) (*http.Cookie, error) {
expires := a.config.CookieExpiry()
data := &ID{
Email: email,
Expand All @@ -110,10 +110,10 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string

encoded, err := a.secureCookie.Encode(a.config.CookieName, data)
if err != nil {
return nil
return nil, err
}

return &http.Cookie{
cookie := &http.Cookie{
Name: a.config.CookieName,
Value: encoded,
Path: "/",
Expand All @@ -122,6 +122,8 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string
Secure: !a.config.InsecureCookie,
Expires: expires,
}

return cookie, nil
}

// MakeNameCookie creates a name cookie
Expand Down
14 changes: 9 additions & 5 deletions internal/authentication/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func TestAuthValidateCookie(t *testing.T) {
// Should catch expired
config.Lifetime = time.Second * time.Duration(-1)
a = NewAuthenticator(config)
c = a.MakeIDCookie(r, "[email protected]", "")
c, err = a.MakeIDCookie(r, "[email protected]", "")
assert.Nil(err)
_, err = a.ValidateCookie(r, c)
if assert.Error(err) {
assert.Equal("securecookie: expired timestamp", err.Error())
Expand All @@ -61,7 +62,8 @@ func TestAuthValidateCookie(t *testing.T) {
// Should accept valid cookie
config.Lifetime = time.Second * time.Duration(10)
a = NewAuthenticator(config)
c = a.MakeIDCookie(r, "[email protected]", "")
c, err = a.MakeIDCookie(r, "[email protected]", "")
assert.Nil(err)
id, err := a.ValidateCookie(r, c)
assert.Nil(err, "valid request should not return an error")
assert.Equal("[email protected]", id.Email, "valid request should return user email")
Expand Down Expand Up @@ -124,10 +126,11 @@ func TestAuthMakeCookie(t *testing.T) {
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
r.Header.Add("X-Forwarded-Host", "app.example.com")

c := a.MakeIDCookie(r, "[email protected]", "")
c, err := a.MakeIDCookie(r, "[email protected]", "")
assert.Nil(err)
assert.Equal("_forward_auth", c.Name)
assert.Greater(len(c.Value), 18, "encoded securecookie should be longer")
_, err := a.ValidateCookie(r, c)
_, err = a.ValidateCookie(r, c)
assert.Nil(err, "should generate valid cookie")
assert.Equal("/", c.Path)
assert.Equal("app.example.com", c.Domain)
Expand All @@ -138,7 +141,8 @@ func TestAuthMakeCookie(t *testing.T) {

config.CookieName = "testname"
config.InsecureCookie = true
c = a.MakeIDCookie(r, "[email protected]", "")
c, err = a.MakeIDCookie(r, "[email protected]", "")
assert.Nil(err)
assert.Equal("testname", c.Name)
assert.False(c.Secure)
}
Expand Down
26 changes: 16 additions & 10 deletions internal/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,15 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
// Check for CSRF cookie
c, err := r.Cookie(s.config.CSRFCookieName)
if err != nil {
logger.Warnf("missing CSRF cookie: %v", err)
logger.Errorf("missing CSRF cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}

// Validate state
valid, redirect, err := authentication.ValidateCSRFCookie(r, c)
if !valid {
logger.Warnf("error validating CSRF cookie: %v", err)
logger.Errorf("error validating CSRF cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
Expand Down Expand Up @@ -301,15 +301,15 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
// Exchange code for token
oauth2Token, err := oauth2Config.Exchange(s.config.OIDCContext, r.URL.Query().Get("code"))
if err != nil {
logger.Warnf("failed to exchange token: %v", err)
logger.Errorf("failed to exchange token: %v", err)
http.Error(w, "Bad Gateway", 502)
return
}

// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
logger.Warnf("missing ID token: %v", err)
logger.Error("missing ID token")
http.Error(w, "Bad Gateway", 502)
return
}
Expand All @@ -318,15 +318,15 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
verifier := provider.Verifier(&oidc.Config{ClientID: s.config.ClientID})
idToken, err := verifier.Verify(s.config.OIDCContext, rawIDToken)
if err != nil {
logger.Warnf("failed to verify token: %v", err)
logger.Errorf("failed to verify token: %v", err)
http.Error(w, "Bad Gateway", 502)
return
}

// Extract custom claims
var claims map[string]interface{}
if err := idToken.Claims(&claims); err != nil {
logger.Warnf("failed to extract claims: %v", err)
logger.Errorf("failed to extract claims: %v", err)
http.Error(w, "Bad Gateway", 502)
return
}
Expand All @@ -339,12 +339,18 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
}

// Generate cookies
http.SetCookie(w, s.authenticator.MakeIDCookie(r, email.(string), token))
c, err := s.authenticator.MakeIDCookie(r, email.(string), token)
if err != nil {
logger.Errorf("error generating secure session cookie: %v", err)
http.Error(w, "Bad Gateway", 502)
return
}
http.SetCookie(w, c)
logger.WithFields(logrus.Fields{
"user": claims["email"].(string),
}).Infof("generated auth cookie")
} else {
logger.Errorf("no email claim present in the ID token")
logger.Warn("no email claim present in the ID token")
}

// If name in null, empty or whitespace, use email address for name
Expand All @@ -356,7 +362,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
http.SetCookie(w, s.authenticator.MakeNameCookie(r, name.(string)))
logger.WithFields(logrus.Fields{
"name": name.(string),
}).Infof("generated name cookie")
}).Info("generated name cookie")

// Mapping groups
groups := []string{}
Expand All @@ -366,7 +372,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
groups = append(groups, g.(string))
}
} else {
logger.Errorf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName)
logger.Warnf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName)
}

if err := s.userinfo.Save(r, w, &v1alpha1.UserInfo{
Expand Down
12 changes: 8 additions & 4 deletions internal/handlers/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
req = newDefaultHTTPRequest("/foo")
// NOTE(jkoelker) `notAuthenticated` will redirect if it thinks the request is from a browser
req.Header.Set("Accept", "application/json")
c := a.MakeIDCookie(req, "[email protected]", "")
c, err := a.MakeIDCookie(req, "[email protected]", "")
assert.Nil(err)
config = newTestConfig(testAuthKey2, testEncKey2) // new auth & encryption key!

config.AuthHost = ""
Expand All @@ -115,7 +116,8 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
// Should validate email
req = newDefaultHTTPRequest("/foo")
a = authentication.NewAuthenticator(config)
c = a.MakeIDCookie(req, "[email protected]", "")
c, err = a.MakeIDCookie(req, "[email protected]", "")
assert.Nil(err)
config.Domains = []string{"test.com"}

res, _ = doHttpRequest(req, c, config)
Expand All @@ -132,7 +134,8 @@ func TestServerAuthHandlerExpired(t *testing.T) {

// Should redirect expired cookie
req := newDefaultHTTPRequest("/foo")
c := a.MakeIDCookie(req, "[email protected]", "")
c, err := a.MakeIDCookie(req, "[email protected]", "")
assert.Nil(err)
res, _ := doHttpRequest(req, c, config)
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")

Expand All @@ -151,7 +154,8 @@ func TestServerAuthHandlerValid(t *testing.T) {

// Should allow valid request email
req := newDefaultHTTPRequest("/foo")
c := a.MakeIDCookie(req, "[email protected]", "")
c, err := a.MakeIDCookie(req, "[email protected]", "")
assert.Nil(err)

config.Domains = []string{}

Expand Down

0 comments on commit 08f8b4d

Please sign in to comment.