Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor mfa challenge and tests #1469

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
ipAddress := utilities.GetIPAddress(r)
challenge := models.NewChallenge(factor, ipAddress)

err := a.db.Transaction(func(tx *storage.Connection) error {
if err := a.db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(challenge); terr != nil {
return terr
}
Expand All @@ -181,8 +181,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
return terr
}
return nil
})
if err != nil {
}); err != nil {
return err
}

Expand All @@ -209,11 +208,10 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
return internalServerError(InvalidFactorOwnerErrorMessage)
}

challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID)
if err != nil {
if models.IsNotFoundError(err) {
return notFoundError(err.Error())
}
challenge, err := models.FindChallengeByID(a.db, params.ChallengeID)
if err != nil && models.IsNotFoundError(err) {
return notFoundError(err.Error())
} else if err != nil {
J0 marked this conversation as resolved.
Show resolved Hide resolved
return internalServerError("Database error finding Challenge").WithInternalError(err)
}

Expand All @@ -222,15 +220,8 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
}

if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) {
err := a.db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Destroy(challenge); terr != nil {
return internalServerError("Database error deleting challenge").WithInternalError(terr)
}

return nil
})
if err != nil {
return err
if err := a.db.Destroy(challenge); err != nil {
return internalServerError("Database error deleting challenge").WithInternalError(err)
}
return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID)
}
Expand Down
51 changes: 21 additions & 30 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
if c.friendlyName != "" && c.expectedCode == http.StatusOK {
require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName)
}

if w.Code == http.StatusOK {
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
Expand Down Expand Up @@ -218,6 +219,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
var buffer bytes.Buffer
r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{})
require.NoError(ts.T(), err)

sharedSecret := ts.TestOTPKey.Secret()
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
f := factors[0]
Expand All @@ -226,11 +228,11 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

token := ts.generateAAL1Token(ts.TestUser, r.SessionId)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer)
testIPAddress := utilities.GetIPAddress(req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

testIPAddress := utilities.GetIPAddress(req)
c := models.NewChallenge(f, testIPAddress)
require.NoError(ts.T(), ts.API.db.Create(c), "Error saving new test challenge")
if !v.validChallenge {
Expand Down Expand Up @@ -263,7 +265,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
}
if !v.validChallenge {
// Ensure invalid challenges are deleted
_, err := models.FindChallengeByChallengeID(ts.API.db, c.ID)
_, err := models.FindChallengeByID(ts.API.db, c.ID)
require.EqualError(ts.T(), err, models.ChallengeNotFoundError{}.Error())
}
})
Expand Down Expand Up @@ -293,7 +295,6 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
if v.isAAL2 {
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
Expand All @@ -303,15 +304,11 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s/", f.ID), &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
ts.API.handler.ServeHTTP(w, req)
w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer)
require.Equal(ts.T(), v.expectedHTTPCode, w.Code)

if v.expectedHTTPCode == http.StatusOK {
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)
_, err := models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
Expand All @@ -333,10 +330,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
"factor_id": f.ID,
}))

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
ts.API.handler.ServeHTTP(w, req)
w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer)
require.Equal(ts.T(), http.StatusOK, w.Code)

_, err := models.FindFactorByFactorID(ts.API.db, f.ID)
Expand Down Expand Up @@ -437,21 +431,24 @@ func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requir

func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder {
var buffer bytes.Buffer
w := httptest.NewRecorder()
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": friendlyName, "factor_type": factorType, "issuer": issuer}))
w := ServeAuthenticatedRequest(ts, http.MethodPost, "http://localhost/factors/", token, buffer)
require.Equal(ts.T(), expectedCode, w.Code)
return w
}

req := httptest.NewRequest(http.MethodPost, "http://localhost/factors/", &buffer)
func ServeAuthenticatedRequest(ts *MFATestSuite, method, path, token string, buffer bytes.Buffer) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
req := httptest.NewRequest(method, path, &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), expectedCode, w.Code)
return w
}

func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, requireStatusOK bool) *httptest.ResponseRecorder {
var verifyBuffer bytes.Buffer
y := httptest.NewRecorder()
var buffer bytes.Buffer

conn, err := pgx.Connect(context.Background(), ts.API.db.URL())
require.NoError(ts.T(), err)
Expand All @@ -465,28 +462,22 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token
code, err := totp.GenerateCode(totpSecret, time.Now().UTC())
require.NoError(ts.T(), err)

require.NoError(ts.T(), json.NewEncoder(&verifyBuffer).Encode(map[string]interface{}{
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"challenge_id": challengeID,
"code": code,
}))
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(y, req)
y := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), token, buffer)

if requireStatusOK {
require.Equal(ts.T(), http.StatusOK, y.Code)
}
return y
}

func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *httptest.ResponseRecorder {
var challengeBuffer bytes.Buffer
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(w, req)
var buffer bytes.Buffer
w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer)
require.Equal(ts.T(), http.StatusOK, w.Code)
return w

Expand Down
22 changes: 7 additions & 15 deletions internal/models/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ func NewChallenge(factor *Factor, ipAddress string) *Challenge {
return challenge
}

func FindChallengeByChallengeID(tx *storage.Connection, challengeID uuid.UUID) (*Challenge, error) {
challenge, err := findChallenge(tx, "id = ?", challengeID)
if err != nil {
func FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) {
var challenge Challenge
err := conn.Find(&challenge, challengeID)
if err != nil && errors.Cause(err) == sql.ErrNoRows {
return nil, ChallengeNotFoundError{}
} else if err != nil {
return nil, err
}
return challenge, nil
return &challenge, nil
}

// Update the verification timestamp
Expand All @@ -55,14 +58,3 @@ func (c *Challenge) HasExpired(expiryDuration float64) bool {
func (c *Challenge) GetExpiryTime(expiryDuration float64) time.Time {
return c.CreatedAt.Add(time.Second * time.Duration(expiryDuration))
}

func findChallenge(tx *storage.Connection, query string, args ...interface{}) (*Challenge, error) {
obj := &Challenge{}
if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, ChallengeNotFoundError{}
}
return nil, errors.Wrap(err, "error finding challenge")
}
return obj, nil
}
Loading