diff --git a/internal/models/factor.go b/internal/models/factor.go index b99410984..d768b1fb9 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -153,24 +153,15 @@ func FindFactorsByUser(tx *storage.Connection, user *User) ([]*Factor, error) { return factors, nil } -func FindFactorByFactorID(tx *storage.Connection, factorID uuid.UUID) (*Factor, error) { - factor, err := findFactor(tx, "id = ?", factorID) - if err != nil { +func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) { + var factor Factor + err := conn.Find(&factor, factorID) + if err != nil && errors.Cause(err) == sql.ErrNoRows { return nil, FactorNotFoundError{} + } else if err != nil { + return nil, err } - return factor, nil -} - -func findFactor(tx *storage.Connection, query string, args ...interface{}) (*Factor, error) { - obj := &Factor{} - if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, FactorNotFoundError{} - } - return nil, errors.Wrap(err, "Database error finding factor") - } - - return obj, nil + return &factor, nil } func DeleteUnverifiedFactors(tx *storage.Connection, user *User) error { diff --git a/internal/models/factor_test.go b/internal/models/factor_test.go index eb5dccb91..c6d1f4a70 100644 --- a/internal/models/factor_test.go +++ b/internal/models/factor_test.go @@ -15,7 +15,8 @@ import ( type FactorTestSuite struct { suite.Suite - db *storage.Connection + db *storage.Connection + TestFactor *Factor } func TestFactor(t *testing.T) { @@ -32,58 +33,40 @@ func TestFactor(t *testing.T) { func (ts *FactorTestSuite) SetupTest() { TruncateAll(ts.db) -} - -func (ts *FactorTestSuite) TestFindFactorByFactorID() { - f := ts.createFactor() - n, err := FindFactorByFactorID(ts.db, f.ID) - require.NoError(ts.T(), err) - require.Equal(ts.T(), f.ID, n.ID) - _, err = FindFactorByFactorID(ts.db, uuid.Nil) - require.EqualError(ts.T(), err, FactorNotFoundError{}.Error()) -} - -func (ts *FactorTestSuite) createFactor() *Factor { user, err := NewUser("", "agenericemail@gmail.com", "secret", "test", nil) require.NoError(ts.T(), err) - - err = ts.db.Create(user) - require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) factor := NewFactor(user, "asimplename", TOTP, FactorStateUnverified, "topsecret") + require.NoError(ts.T(), ts.db.Create(factor)) + ts.TestFactor = factor +} - err = ts.db.Create(factor) +func (ts *FactorTestSuite) TestFindFactorByFactorID() { + n, err := FindFactorByFactorID(ts.db, ts.TestFactor.ID) require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.TestFactor.ID, n.ID) - return factor + _, err = FindFactorByFactorID(ts.db, uuid.Nil) + require.EqualError(ts.T(), err, FactorNotFoundError{}.Error()) } + func (ts *FactorTestSuite) TestUpdateStatus() { newFactorStatus := FactorStateVerified - u, err := NewUser("", "", "", "", nil) - require.NoError(ts.T(), err) - - f := NewFactor(u, "", TOTP, FactorStateUnverified, "some-secret") - require.NoError(ts.T(), f.UpdateStatus(ts.db, newFactorStatus)) - require.Equal(ts.T(), newFactorStatus.String(), f.Status) + require.NoError(ts.T(), ts.TestFactor.UpdateStatus(ts.db, newFactorStatus)) + require.Equal(ts.T(), newFactorStatus.String(), ts.TestFactor.Status) } func (ts *FactorTestSuite) TestUpdateFriendlyName() { - newSimpleName := "newFactorName" - u, err := NewUser("", "", "", "", nil) - require.NoError(ts.T(), err) - - f := NewFactor(u, "A1B2C3", TOTP, FactorStateUnverified, "some-secret") - require.NoError(ts.T(), f.UpdateFriendlyName(ts.db, newSimpleName)) - require.Equal(ts.T(), newSimpleName, f.FriendlyName) + newName := "newfactorname" + require.NoError(ts.T(), ts.TestFactor.UpdateFriendlyName(ts.db, newName)) + require.Equal(ts.T(), newName, ts.TestFactor.FriendlyName) } func (ts *FactorTestSuite) TestEncodedFactorDoesNotLeakSecret() { - u, err := NewUser("", "", "", "", nil) + encodedFactor, err := json.Marshal(ts.TestFactor) require.NoError(ts.T(), err) - f := NewFactor(u, "A1B2C3", TOTP, FactorStateUnverified, "some-secret") - encodedFactor, err := json.Marshal(f) - require.NoError(ts.T(), err) decodedFactor := Factor{} json.Unmarshal(encodedFactor, &decodedFactor) require.Equal(ts.T(), decodedFactor.Secret, "")