diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 9b2253185d..6bdb3e5962 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -64,12 +64,16 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { session := getSession(ctx) config := a.config + if session == nil || user == nil { + return internalServerError("A valid session and a registered user are required to enroll a factor") + } + params := &EnrollFactorParams{} if err := retrieveRequestParams(r, params); err != nil { return err } - issuer := "" + issuer := "" if params.FactorType != models.TOTP { return badRequestError("factor_type needs to be totp") } @@ -84,25 +88,26 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { issuer = params.Issuer } - // Read from DB for certainty - factors, err := models.FindFactorsByUser(a.db, user) - if err != nil { - return internalServerError("error validating number of factors in system").WithInternalError(err) - } + factors := user.Factors - if len(factors) >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + factorCount := len(factors) + numVerifiedFactors := 0 + if err := models.DeleteExpiredFactors(a.db, config.MFA.FactorExpiryDuration); err != nil { + return err } - numVerifiedFactors := 0 for _, factor := range factors { if factor.IsVerified() { numVerifiedFactors += 1 } } + if factorCount >= int(config.MFA.MaxEnrolledFactors) { + return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + } + if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of enrolled factors reached, unenroll to continue") + return forbiddenError("Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors > 0 && !session.IsAAL2() { @@ -116,6 +121,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) } + var buf bytes.Buffer svgData := svg.New(&buf) qrCode, _ := qr.Encode(key.String(), qr.H, qr.Auto) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 32fd2f8c78..bb3c919681 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -180,6 +180,34 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() { } +func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { + // All factors are deleted when a subsequent enroll is made + ts.API.config.MFA.FactorExpiryDuration = 0 * time.Second + // Verified factor should not be deleted (Factor 1) + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + numFactors := 5 + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + token := accessTokenResp.Token + for i := 0; i < numFactors; i++ { + _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK) + } + + // All Factors except last factor should be expired + factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) + require.NoError(ts.T(), err) + + // Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2) + _ = performChallengeFlow(ts, factors[len(factors)-1].ID, token) + + // Enroll another Factor (Factor 3) + _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK) + factors, err = models.FindFactorsByUser(ts.API.db, ts.TestUser) + require.NoError(ts.T(), err) + require.Equal(ts.T(), 3, len(factors)) +} + func (ts *MFATestSuite) TestChallengeFactor() { f := ts.TestUser.Factors[0] token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) @@ -431,7 +459,7 @@ func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requir func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder { var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": friendlyName, "factor_type": factorType, "issuer": issuer})) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(EnrollFactorParams{FriendlyName: friendlyName, FactorType: factorType, Issuer: issuer})) w := ServeAuthenticatedRequest(ts, http.MethodPost, "http://localhost/factors/", token, buffer) require.Equal(ts.T(), expectedCode, w.Code) return w diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index d235987954..31fb8b22d3 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -18,6 +18,7 @@ import ( const defaultMinPasswordLength int = 6 const defaultChallengeExpiryDuration float64 = 300 +const defaultFactorExpiryDuration time.Duration = 300 * time.Second const defaultFlowStateExpiryDuration time.Duration = 300 * time.Second // See: https://www.postgresql.org/docs/7.0/syntax525.htm @@ -102,11 +103,12 @@ type JWTConfiguration struct { // MFAConfiguration holds all the MFA related Configuration type MFAConfiguration struct { - Enabled bool `default:"false"` - ChallengeExpiryDuration float64 `json:"challenge_expiry_duration" default:"300" split_words:"true"` - RateLimitChallengeAndVerify float64 `split_words:"true" default:"15"` - MaxEnrolledFactors float64 `split_words:"true" default:"10"` - MaxVerifiedFactors int `split_words:"true" default:"10"` + Enabled bool `default:"false"` + ChallengeExpiryDuration float64 `json:"challenge_expiry_duration" default:"300" split_words:"true"` + FactorExpiryDuration time.Duration `json:"factor_expiry_duration" default:"300s" split_words:"true"` + RateLimitChallengeAndVerify float64 `split_words:"true" default:"15"` + MaxEnrolledFactors float64 `split_words:"true" default:"10"` + MaxVerifiedFactors int `split_words:"true" default:"10"` } type APIConfiguration struct { @@ -711,6 +713,9 @@ func (config *GlobalConfiguration) ApplyDefaults() error { if config.MFA.ChallengeExpiryDuration < defaultChallengeExpiryDuration { config.MFA.ChallengeExpiryDuration = defaultChallengeExpiryDuration } + if config.MFA.FactorExpiryDuration < defaultFactorExpiryDuration { + config.MFA.FactorExpiryDuration = defaultFactorExpiryDuration + } if config.External.FlowStateExpiryDuration < defaultFlowStateExpiryDuration { config.External.FlowStateExpiryDuration = defaultFlowStateExpiryDuration } diff --git a/internal/models/factor.go b/internal/models/factor.go index d768b1fb91..24afb77415 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -217,3 +217,17 @@ func DeleteFactorsByUserId(tx *storage.Connection, userId uuid.UUID) error { } return nil } + +func DeleteExpiredFactors(tx *storage.Connection, validityDuration time.Duration) error { + totalSeconds := int64(validityDuration / time.Second) + validityInterval := fmt.Sprintf("interval '%d seconds'", totalSeconds) + + factorTable := (&pop.Model{Value: Factor{}}).TableName() + challengeTable := (&pop.Model{Value: Challenge{}}).TableName() + + query := fmt.Sprintf(`delete from %q where status != 'verified' and not exists (select * from %q where %q.id = %q.factor_id ) and created_at + %s < current_timestamp;`, factorTable, challengeTable, factorTable, challengeTable, validityInterval) + if err := tx.RawQuery(query).Exec(); err != nil { + return err + } + return nil +}