From 6c76f21cee5dbef0562c37df6a546939affb2f8d Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Wed, 6 Mar 2024 01:10:39 +0800 Subject: [PATCH] fix: refactor mfa challenge and tests (#1469) ## What kind of change does this PR introduce? Does a few refactors: - Create a dedicated function in tests for making authenticated request - remove `findChallenge` as it is not used - rename `FindChallengeByChallengeID` to `FindChallengeByID` Co-authored-by: joel --- internal/api/mfa.go | 25 ++++++------------ internal/api/mfa_test.go | 51 +++++++++++++++--------------------- internal/models/challenge.go | 22 +++++----------- 3 files changed, 36 insertions(+), 62 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 256d3e3d3..9b2253185 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -170,7 +170,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { ipAddress := utilities.GetIPAddress(r) challenge := models.NewChallenge(factor, ipAddress) - err := a.db.Transaction(func(tx *storage.Connection) error { + if err := a.db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(challenge); terr != nil { return terr } @@ -181,8 +181,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { return terr } return nil - }) - if err != nil { + }); err != nil { return err } @@ -209,11 +208,10 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { return internalServerError(InvalidFactorOwnerErrorMessage) } - challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } + challenge, err := models.FindChallengeByID(a.db, params.ChallengeID) + if err != nil && models.IsNotFoundError(err) { + return notFoundError(err.Error()) + } else if err != nil { return internalServerError("Database error finding Challenge").WithInternalError(err) } @@ -222,15 +220,8 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { - err := a.db.Transaction(func(tx *storage.Connection) error { - if terr := tx.Destroy(challenge); terr != nil { - return internalServerError("Database error deleting challenge").WithInternalError(terr) - } - - return nil - }) - if err != nil { - return err + if err := a.db.Destroy(challenge); err != nil { + return internalServerError("Database error deleting challenge").WithInternalError(err) } return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) } diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index df7b1921b..32fd2f8c7 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -150,6 +150,7 @@ func (ts *MFATestSuite) TestEnrollFactor() { if c.friendlyName != "" && c.expectedCode == http.StatusOK { require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName) } + if w.Code == http.StatusOK { enrollResp := EnrollFactorResponse{} require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) @@ -218,6 +219,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { var buffer bytes.Buffer r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{}) require.NoError(ts.T(), err) + sharedSecret := ts.TestOTPKey.Secret() factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) f := factors[0] @@ -226,11 +228,11 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor") token := ts.generateAAL1Token(ts.TestUser, r.SessionId) - w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer) - testIPAddress := utilities.GetIPAddress(req) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + testIPAddress := utilities.GetIPAddress(req) c := models.NewChallenge(f, testIPAddress) require.NoError(ts.T(), ts.API.db.Create(c), "Error saving new test challenge") if !v.validChallenge { @@ -263,7 +265,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { } if !v.validChallenge { // Ensure invalid challenges are deleted - _, err := models.FindChallengeByChallengeID(ts.API.db, c.ID) + _, err := models.FindChallengeByID(ts.API.db, c.ID) require.EqualError(ts.T(), err, models.ChallengeNotFoundError{}.Error()) } }) @@ -293,7 +295,6 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { if v.isAAL2 { ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String()) } - // Create Session to test behaviour which downgrades other sessions factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) require.NoError(ts.T(), err, "error finding factors") @@ -303,15 +304,11 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor") token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s/", f.ID), &buffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - ts.API.handler.ServeHTTP(w, req) + w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) require.Equal(ts.T(), v.expectedHTTPCode, w.Code) if v.expectedHTTPCode == http.StatusOK { - _, err = models.FindFactorByFactorID(ts.API.db, f.ID) + _, err := models.FindFactorByFactorID(ts.API.db, f.ID) require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) require.Equal(ts.T(), models.AAL1.String(), session.GetAAL()) @@ -333,10 +330,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { "factor_id": f.ID, })) - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), &buffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - ts.API.handler.ServeHTTP(w, req) + w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) require.Equal(ts.T(), http.StatusOK, w.Code) _, err := models.FindFactorByFactorID(ts.API.db, f.ID) @@ -437,21 +431,24 @@ func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requir func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder { var buffer bytes.Buffer - w := httptest.NewRecorder() require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": friendlyName, "factor_type": factorType, "issuer": issuer})) + w := ServeAuthenticatedRequest(ts, http.MethodPost, "http://localhost/factors/", token, buffer) + require.Equal(ts.T(), expectedCode, w.Code) + return w +} - req := httptest.NewRequest(http.MethodPost, "http://localhost/factors/", &buffer) +func ServeAuthenticatedRequest(ts *MFATestSuite, method, path, token string, buffer bytes.Buffer) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + req := httptest.NewRequest(method, path, &buffer) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Set("Content-Type", "application/json") ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), expectedCode, w.Code) return w } func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, requireStatusOK bool) *httptest.ResponseRecorder { - var verifyBuffer bytes.Buffer - y := httptest.NewRecorder() + var buffer bytes.Buffer conn, err := pgx.Connect(context.Background(), ts.API.db.URL()) require.NoError(ts.T(), err) @@ -465,15 +462,13 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token code, err := totp.GenerateCode(totpSecret, time.Now().UTC()) require.NoError(ts.T(), err) - require.NoError(ts.T(), json.NewEncoder(&verifyBuffer).Encode(map[string]interface{}{ + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ "challenge_id": challengeID, "code": code, })) - req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - req.Header.Set("Content-Type", "application/json") - ts.API.handler.ServeHTTP(y, req) + y := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), token, buffer) + if requireStatusOK { require.Equal(ts.T(), http.StatusOK, y.Code) } @@ -481,12 +476,8 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token } func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *httptest.ResponseRecorder { - var challengeBuffer bytes.Buffer - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - req.Header.Set("Content-Type", "application/json") - ts.API.handler.ServeHTTP(w, req) + var buffer bytes.Buffer + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) require.Equal(ts.T(), http.StatusOK, w.Code) return w diff --git a/internal/models/challenge.go b/internal/models/challenge.go index 99758e63b..d52132aaa 100644 --- a/internal/models/challenge.go +++ b/internal/models/challenge.go @@ -33,12 +33,15 @@ func NewChallenge(factor *Factor, ipAddress string) *Challenge { return challenge } -func FindChallengeByChallengeID(tx *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { - challenge, err := findChallenge(tx, "id = ?", challengeID) - if err != nil { +func FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { + var challenge Challenge + err := conn.Find(&challenge, challengeID) + if err != nil && errors.Cause(err) == sql.ErrNoRows { return nil, ChallengeNotFoundError{} + } else if err != nil { + return nil, err } - return challenge, nil + return &challenge, nil } // Update the verification timestamp @@ -55,14 +58,3 @@ func (c *Challenge) HasExpired(expiryDuration float64) bool { func (c *Challenge) GetExpiryTime(expiryDuration float64) time.Time { return c.CreatedAt.Add(time.Second * time.Duration(expiryDuration)) } - -func findChallenge(tx *storage.Connection, query string, args ...interface{}) (*Challenge, error) { - obj := &Challenge{} - if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, ChallengeNotFoundError{} - } - return nil, errors.Wrap(err, "error finding challenge") - } - return obj, nil -}