From edd4540ded8c54f7f9fc7ad0d2cfb93a3cc2d522 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 9 Jan 2024 16:42:37 +0800 Subject: [PATCH] feat: clean up expired factors --- internal/api/mfa.go | 23 ++++++++++++++++------- internal/api/mfa_test.go | 27 +++++++++++++++++++++++++++ internal/conf/configuration.go | 15 ++++++++++----- internal/models/factor.go | 6 +++++- 4 files changed, 58 insertions(+), 13 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 434a7117c..f30a05065 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -89,25 +89,33 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { issuer = params.Issuer } - // Read from DB for certainty factors, err := models.FindFactorsByUser(a.db, user) if err != nil { return internalServerError("error validating number of factors in system").WithInternalError(err) } - - if len(factors) >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") - } - + factorCount := len(factors) numVerifiedFactors := 0 + + // Cleanup inactive factors for _, factor := range factors { + if factor.IsExpired(config.MFA.FactorExpiryDuration) { + if err := a.db.Destroy(factor); err != nil { + return internalServerError("error deleting factors").WithInternalError(err) + } + // We adjust length of factors as destroying it in the DB doesn't remove it from the array + factorCount -= 1 + } if factor.IsVerified() { numVerifiedFactors += 1 } } + if factorCount >= int(config.MFA.MaxEnrolledFactors) { + return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + } + if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of enrolled factors reached, unenroll to continue") + return forbiddenError("Maximum number of verified factors reached, unenroll to continue") } key, err := totp.Generate(totp.GenerateOpts{ @@ -117,6 +125,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) } + var buf bytes.Buffer svgData := svg.New(&buf) qrCode, _ := qr.Encode(key.String(), qr.M, qr.Auto) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 5fe7f0bc6..e6b76f1ad 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -175,6 +175,33 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() { } +func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { + // All factors are deleted when a subsequent enroll is made + ts.API.config.MFA.FactorExpiryDuration = 0 * time.Second + // Verified factor should not be deleted (Factor 1) + _ = performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + numFactors := 5 + token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn) + require.NoError(ts.T(), err) + + for i := 0; i < numFactors; i++ { + _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK) + } + + // All Factors except last factor should be expired + factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) + require.NoError(ts.T(), err) + + // Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2) + _ = performChallengeFlow(ts, factors[len(factors)-1].ID, token) + + // Enroll another Factor (Factor 3) + _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK) + factors, err = models.FindFactorsByUser(ts.API.db, ts.TestUser) + require.NoError(ts.T(), err) + require.Equal(ts.T(), 3, len(factors)) +} + func (ts *MFATestSuite) TestChallengeFactor() { f := ts.TestUser.Factors[0] token := ts.generateToken(ts.TestUser, nil) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index d33cbd9c7..2bf9d3d75 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -18,6 +18,7 @@ import ( const defaultMinPasswordLength int = 6 const defaultChallengeExpiryDuration float64 = 300 +const defaultFactorExpiryDuration time.Duration = 300 * time.Second const defaultFlowStateExpiryDuration time.Duration = 300 * time.Second var postgresNamesRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]{0,62}$`) @@ -90,11 +91,12 @@ type JWTConfiguration struct { // MFAConfiguration holds all the MFA related Configuration type MFAConfiguration struct { - Enabled bool `default:"false"` - ChallengeExpiryDuration float64 `json:"challenge_expiry_duration" default:"300" split_words:"true"` - RateLimitChallengeAndVerify float64 `split_words:"true" default:"15"` - MaxEnrolledFactors float64 `split_words:"true" default:"10"` - MaxVerifiedFactors int `split_words:"true" default:"10"` + Enabled bool `default:"false"` + ChallengeExpiryDuration float64 `json:"challenge_expiry_duration" default:"300" split_words:"true"` + FactorExpiryDuration time.Duration `json:"factor_expiry_duration" default:"300s" split_words:"true"` + RateLimitChallengeAndVerify float64 `split_words:"true" default:"15"` + MaxEnrolledFactors float64 `split_words:"true" default:"10"` + MaxVerifiedFactors int `split_words:"true" default:"10"` } type APIConfiguration struct { @@ -669,6 +671,9 @@ func (config *GlobalConfiguration) ApplyDefaults() error { if config.MFA.ChallengeExpiryDuration < defaultChallengeExpiryDuration { config.MFA.ChallengeExpiryDuration = defaultChallengeExpiryDuration } + if config.MFA.FactorExpiryDuration < defaultFactorExpiryDuration { + config.MFA.FactorExpiryDuration = defaultFactorExpiryDuration + } if config.External.FlowStateExpiryDuration < defaultFlowStateExpiryDuration { config.External.FlowStateExpiryDuration = defaultFlowStateExpiryDuration } diff --git a/internal/models/factor.go b/internal/models/factor.go index 5af45772a..05d48f12b 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -141,7 +141,7 @@ func NewFactor(user *User, friendlyName string, factorType string, state FactorS // FindFactorsByUser returns all factors belonging to a user ordered by timestamp func FindFactorsByUser(tx *storage.Connection, user *User) ([]*Factor, error) { factors := []*Factor{} - if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil { + if err := tx.Eager().Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil { if errors.Cause(err) == sql.ErrNoRows { return factors, nil } @@ -223,3 +223,7 @@ func DeleteFactorsByUserId(tx *storage.Connection, userId uuid.UUID) error { } return nil } + +func (f *Factor) IsExpired(validityDuration time.Duration) bool { + return !f.IsVerified() && len(f.Challenge) == 0 && f.CreatedAt.Add(validityDuration).Before(time.Now()) +}