Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add reuse interval for token refresh #466

Merged
merged 12 commits into from
May 5, 2022
56 changes: 41 additions & 15 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,26 +288,50 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}
}

var newToken *models.RefreshToken
if token.Revoked {
a.clearCookieTokens(config, w)
if config.Security.RefreshTokenRotationEnabled {
// Revoke all tokens in token family
err = a.db.Transaction(func(tx *storage.Connection) error {
var terr error
if terr = models.RevokeTokenFamily(tx, token); terr != nil {
return terr
err = a.db.Transaction(func(tx *storage.Connection) error {
validToken, terr := models.GetCurrentValidToken(tx, token)
if terr != nil {
if errors.Is(terr, models.RefreshTokenNotFoundError{}) {
// revoked token has no descendants
return nil
}
return terr
}
// check if token is the last previous revoked token
if validToken.Parent == storage.NullString(token.Token) {
refreshTokenReuseWindow := token.UpdatedAt.Add(time.Second * time.Duration(config.Security.RefreshTokenReuseInterval))
if time.Now().Before(refreshTokenReuseWindow) {
newToken = validToken
}
return nil
})
if err != nil {
return internalServerError(err.Error())
}
return nil
})
if err != nil {
return internalServerError("Error validating reuse interval").WithInternalError(err)
}

if newToken == nil {
if config.Security.RefreshTokenRotationEnabled {
// Revoke all tokens in token family
err = a.db.Transaction(func(tx *storage.Connection) error {
var terr error
if terr = models.RevokeTokenFamily(tx, token); terr != nil {
return terr
}
return nil
})
if err != nil {
return internalServerError(err.Error())
}
}
return oauthError("invalid_grant", "Invalid Refresh Token").WithInternalMessage("Possible abuse attempt: %v", r)
}
return oauthError("invalid_grant", "Invalid Refresh Token").WithInternalMessage("Possible abuse attempt: %v", r)
}

var tokenString string
var newToken *models.RefreshToken
var newTokenResponse *AccessTokenResponse

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -316,9 +340,11 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return terr
}

newToken, terr = models.GrantRefreshTokenSwap(tx, user, token)
if terr != nil {
return internalServerError(terr.Error())
if newToken == nil {
newToken, terr = models.GrantRefreshTokenSwap(tx, user, token)
if terr != nil {
return internalServerError(terr.Error())
}
}

tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
Expand Down
89 changes: 89 additions & 0 deletions api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,95 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() {
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
u, err := models.NewUser(ts.instanceID, "[email protected]", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving foo user")
first, err := models.GrantAuthenticatedUser(ts.API.db, u)
require.NoError(ts.T(), err)
second, err := models.GrantRefreshTokenSwap(ts.API.db, u, first)
require.NoError(ts.T(), err)
third, err := models.GrantRefreshTokenSwap(ts.API.db, u, second)
require.NoError(ts.T(), err)

cases := []struct {
desc string
refreshTokenRotationEnabled bool
reuseInterval int
refreshToken string
expectedCode int
expectedBody map[string]interface{}
}{
{
"Valid refresh within reuse interval",
true,
30,
second.Token,
http.StatusOK,
map[string]interface{}{
"refresh_token": third.Token,
},
},
{
"Invalid refresh, first token is not the previous revoked token",
true,
0,
first.Token,
http.StatusBadRequest,
map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
},
},
{
"Invalid refresh, revoked third token",
true,
0,
second.Token,
http.StatusBadRequest,
map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
},
},
{
"Invalid refresh, third token revoked by previous case",
true,
30,
third.Token,
http.StatusBadRequest,
map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
},
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.Security.RefreshTokenRotationEnabled = c.refreshTokenRotationEnabled
ts.Config.Security.RefreshTokenReuseInterval = c.reuseInterval
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": c.refreshToken,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), c.expectedCode, w.Code)

data := make(map[string]interface{})
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
for k, v := range c.expectedBody {
require.Equal(ts.T(), v, data[k])
}
})
}
}

func (ts *TokenTestSuite) createBannedUser() *models.User {
u, err := models.NewUser(ts.instanceID, "[email protected]", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
Expand Down
1 change: 1 addition & 0 deletions conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ type CaptchaConfiguration struct {
type SecurityConfiguration struct {
Captcha CaptchaConfiguration `json:"captcha"`
RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"`
RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"`
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
}

Expand Down
28 changes: 25 additions & 3 deletions models/refresh_token.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package models

import (
"database/sql"
"time"

"github.com/gobuffalo/pop/v5"
Expand Down Expand Up @@ -57,19 +58,40 @@ func GrantRefreshTokenSwap(tx *storage.Connection, user *User, token *RefreshTok

// RevokeTokenFamily revokes all refresh tokens that descended from the provided token.
func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error {
tablename := (&pop.Model{Value: RefreshToken{}}).TableName()
err := tx.RawQuery(`
with recursive token_family as (
select id, user_id, token, revoked, parent from refresh_tokens where parent = ?
select id, user_id, token, revoked, parent from `+tablename+` where parent = ?
union
select r.id, r.user_id, r.token, r.revoked, r.parent from `+(&pop.Model{Value: RefreshToken{}}).TableName()+` r inner join token_family t on t.token = r.parent
select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent
)
update `+(&pop.Model{Value: RefreshToken{}}).TableName()+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec()
update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec()
if err != nil {
return err
}
return nil
}

// GetCurrentValidToken finds the most recent unrevoked token descending from the token provided.
func GetCurrentValidToken(tx *storage.Connection, token *RefreshToken) (*RefreshToken, error) {
tablename := (&pop.Model{Value: RefreshToken{}}).TableName()
refreshToken := &RefreshToken{}
err := tx.RawQuery(`with recursive token_family as (
select id, user_id, token, revoked, parent from `+tablename+` where parent = ?
union
select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent
)
select * from token_family where id = (select max(id) from token_family)
`, token.Token).First(refreshToken)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, RefreshTokenNotFoundError{}
}
return nil, err
}
return refreshToken, nil
}

// Logout deletes all refresh tokens for a user.
func Logout(tx *storage.Connection, instanceID uuid.UUID, id uuid.UUID) error {
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: RefreshToken{}}).TableName()+" WHERE instance_id = ? AND user_id = ?", instanceID, id).Exec()
Expand Down