Skip to content

Commit

Permalink
fix: refactor mfa challenge and tests (supabase#1469)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Does a few refactors:
- Create a dedicated function in tests for making authenticated request
- remove `findChallenge` as it is not used 
- rename `FindChallengeByChallengeID` to `FindChallengeByID`

Co-authored-by: joel <[email protected]>
  • Loading branch information
J0 and joel authored Mar 5, 2024
1 parent 4392a08 commit 6c76f21
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 62 deletions.
25 changes: 8 additions & 17 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
ipAddress := utilities.GetIPAddress(r)
challenge := models.NewChallenge(factor, ipAddress)

err := a.db.Transaction(func(tx *storage.Connection) error {
if err := a.db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(challenge); terr != nil {
return terr
}
Expand All @@ -181,8 +181,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
return terr
}
return nil
})
if err != nil {
}); err != nil {
return err
}

Expand All @@ -209,11 +208,10 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
return internalServerError(InvalidFactorOwnerErrorMessage)
}

challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID)
if err != nil {
if models.IsNotFoundError(err) {
return notFoundError(err.Error())
}
challenge, err := models.FindChallengeByID(a.db, params.ChallengeID)
if err != nil && models.IsNotFoundError(err) {
return notFoundError(err.Error())
} else if err != nil {
return internalServerError("Database error finding Challenge").WithInternalError(err)
}

Expand All @@ -222,15 +220,8 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
}

if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) {
err := a.db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Destroy(challenge); terr != nil {
return internalServerError("Database error deleting challenge").WithInternalError(terr)
}

return nil
})
if err != nil {
return err
if err := a.db.Destroy(challenge); err != nil {
return internalServerError("Database error deleting challenge").WithInternalError(err)
}
return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID)
}
Expand Down
51 changes: 21 additions & 30 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
if c.friendlyName != "" && c.expectedCode == http.StatusOK {
require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName)
}

if w.Code == http.StatusOK {
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
Expand Down Expand Up @@ -218,6 +219,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
var buffer bytes.Buffer
r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{})
require.NoError(ts.T(), err)

sharedSecret := ts.TestOTPKey.Secret()
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
f := factors[0]
Expand All @@ -226,11 +228,11 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

token := ts.generateAAL1Token(ts.TestUser, r.SessionId)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer)
testIPAddress := utilities.GetIPAddress(req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

testIPAddress := utilities.GetIPAddress(req)
c := models.NewChallenge(f, testIPAddress)
require.NoError(ts.T(), ts.API.db.Create(c), "Error saving new test challenge")
if !v.validChallenge {
Expand Down Expand Up @@ -263,7 +265,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
}
if !v.validChallenge {
// Ensure invalid challenges are deleted
_, err := models.FindChallengeByChallengeID(ts.API.db, c.ID)
_, err := models.FindChallengeByID(ts.API.db, c.ID)
require.EqualError(ts.T(), err, models.ChallengeNotFoundError{}.Error())
}
})
Expand Down Expand Up @@ -293,7 +295,6 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
if v.isAAL2 {
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
Expand All @@ -303,15 +304,11 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s/", f.ID), &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
ts.API.handler.ServeHTTP(w, req)
w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer)
require.Equal(ts.T(), v.expectedHTTPCode, w.Code)

if v.expectedHTTPCode == http.StatusOK {
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)
_, err := models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
Expand All @@ -333,10 +330,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
"factor_id": f.ID,
}))

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
ts.API.handler.ServeHTTP(w, req)
w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer)
require.Equal(ts.T(), http.StatusOK, w.Code)

_, err := models.FindFactorByFactorID(ts.API.db, f.ID)
Expand Down Expand Up @@ -437,21 +431,24 @@ func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requir

func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder {
var buffer bytes.Buffer
w := httptest.NewRecorder()
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": friendlyName, "factor_type": factorType, "issuer": issuer}))
w := ServeAuthenticatedRequest(ts, http.MethodPost, "http://localhost/factors/", token, buffer)
require.Equal(ts.T(), expectedCode, w.Code)
return w
}

req := httptest.NewRequest(http.MethodPost, "http://localhost/factors/", &buffer)
func ServeAuthenticatedRequest(ts *MFATestSuite, method, path, token string, buffer bytes.Buffer) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
req := httptest.NewRequest(method, path, &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), expectedCode, w.Code)
return w
}

func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, requireStatusOK bool) *httptest.ResponseRecorder {
var verifyBuffer bytes.Buffer
y := httptest.NewRecorder()
var buffer bytes.Buffer

conn, err := pgx.Connect(context.Background(), ts.API.db.URL())
require.NoError(ts.T(), err)
Expand All @@ -465,28 +462,22 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token
code, err := totp.GenerateCode(totpSecret, time.Now().UTC())
require.NoError(ts.T(), err)

require.NoError(ts.T(), json.NewEncoder(&verifyBuffer).Encode(map[string]interface{}{
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"challenge_id": challengeID,
"code": code,
}))
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(y, req)
y := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), token, buffer)

if requireStatusOK {
require.Equal(ts.T(), http.StatusOK, y.Code)
}
return y
}

func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *httptest.ResponseRecorder {
var challengeBuffer bytes.Buffer
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(w, req)
var buffer bytes.Buffer
w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer)
require.Equal(ts.T(), http.StatusOK, w.Code)
return w

Expand Down
22 changes: 7 additions & 15 deletions internal/models/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ func NewChallenge(factor *Factor, ipAddress string) *Challenge {
return challenge
}

func FindChallengeByChallengeID(tx *storage.Connection, challengeID uuid.UUID) (*Challenge, error) {
challenge, err := findChallenge(tx, "id = ?", challengeID)
if err != nil {
func FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) {
var challenge Challenge
err := conn.Find(&challenge, challengeID)
if err != nil && errors.Cause(err) == sql.ErrNoRows {
return nil, ChallengeNotFoundError{}
} else if err != nil {
return nil, err
}
return challenge, nil
return &challenge, nil
}

// Update the verification timestamp
Expand All @@ -55,14 +58,3 @@ func (c *Challenge) HasExpired(expiryDuration float64) bool {
func (c *Challenge) GetExpiryTime(expiryDuration float64) time.Time {
return c.CreatedAt.Add(time.Second * time.Duration(expiryDuration))
}

func findChallenge(tx *storage.Connection, query string, args ...interface{}) (*Challenge, error) {
obj := &Challenge{}
if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, ChallengeNotFoundError{}
}
return nil, errors.Wrap(err, "error finding challenge")
}
return obj, nil
}

0 comments on commit 6c76f21

Please sign in to comment.