Skip to content

Commit

Permalink
fix: unlink identity bugs (supabase#1475)
Browse files Browse the repository at this point in the history
  • Loading branch information
kangmingtay authored and LashaJini committed Nov 13, 2024
1 parent 2ef26b3 commit b419f0b
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 6 deletions.
20 changes: 16 additions & 4 deletions internal/api/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,23 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
if terr := tx.Destroy(identityToBeDeleted); terr != nil {
return internalServerError("Database error deleting identity").WithInternalError(terr)
}
if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil {
if models.IsUniqueConstraintViolatedError(terr) {
return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr)

switch identityToBeDeleted.Provider {
case "phone":
user.PhoneConfirmedAt = nil
if terr := user.SetPhone(tx, ""); terr != nil {
return internalServerError("Database error updating user phone").WithInternalError(terr)
}
if terr := tx.UpdateOnly(user, "phone_confirmed_at"); terr != nil {
return internalServerError("Database error updating user phone").WithInternalError(terr)
}
default:
if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil {
if models.IsUniqueConstraintViolatedError(terr) {
return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr)
}
return internalServerError("Database error updating user email").WithInternalError(terr)
}
return internalServerError("Database error updating user email").WithInternalError(terr)
}
if terr := user.UpdateAppMetaDataProviders(tx); terr != nil {
return internalServerError("Database error updating user providers").WithInternalError(terr)
Expand Down
138 changes: 136 additions & 2 deletions internal/api/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/gofrs/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/api/provider"
Expand Down Expand Up @@ -34,9 +37,10 @@ func (ts *IdentityTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)

// Create user
u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil)
u, err := models.NewUser("", "one@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
require.NoError(ts.T(), u.Confirm(ts.API.db))

// Create identity
i, err := models.NewIdentity(u, "email", map[string]interface{}{
Expand All @@ -45,10 +49,31 @@ func (ts *IdentityTestSuite) SetupTest() {
})
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(i))

// Create user with 2 identities
u, err = models.NewUser("123456789", "[email protected]", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
require.NoError(ts.T(), u.Confirm(ts.API.db))
require.NoError(ts.T(), u.ConfirmPhone(ts.API.db))

i, err = models.NewIdentity(u, "email", map[string]interface{}{
"sub": u.ID.String(),
"email": u.GetEmail(),
})
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(i))

i2, err := models.NewIdentity(u, "phone", map[string]interface{}{
"sub": u.ID.String(),
"phone": u.GetPhone(),
})
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(i2))
}

func (ts *IdentityTestSuite) TestLinkIdentityToUser() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
u, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
ctx := withTargetUser(context.Background(), u)

Expand Down Expand Up @@ -79,3 +104,112 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() {
require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked"))
require.Nil(ts.T(), u)
}

func (ts *IdentityTestSuite) TestUnlinkIdentityError() {
ts.Config.Security.ManualLinkingEnabled = true
userWithOneIdentity, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

userWithTwoIdentities, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
cases := []struct {
desc string
user *models.User
identityId uuid.UUID
expectedError *HTTPError
}{
{
desc: "User must have at least 1 identity after unlinking",
user: userWithOneIdentity,
identityId: userWithOneIdentity.Identities[0].ID,
expectedError: badRequestError("User must have at least 1 identity after unlinking"),
},
{
desc: "Identity doesn't exist",
user: userWithTwoIdentities,
identityId: uuid.Must(uuid.NewV4()),
expectedError: badRequestError("Identity doesn't exist"),
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, c.user, nil, models.PasswordGrant)
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
w := httptest.NewRecorder()

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expectedError.Code, w.Code)

var data HTTPError
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), c.expectedError.Message, data.Message)
})
}
}

func (ts *IdentityTestSuite) TestUnlinkIdentity() {
ts.Config.Security.ManualLinkingEnabled = true

// we want to test 2 cases here: unlinking a phone identity and email identity from a user
cases := []struct {
desc string
// the provider to be unlinked
provider string
// the remaining provider that should be linked to the user
providerRemaining string
}{
{
desc: "Unlink phone identity successfully",
provider: "phone",
providerRemaining: "email",
},
{
desc: "Unlink email identity successfully",
provider: "email",
providerRemaining: "phone",
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
// teardown and reset the state of the db to prevent running into errors
ts.SetupTest()
u, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider)
require.NoError(ts.T(), err)

token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)

// sanity checks
u, err = models.FindUserByID(ts.API.db, u.ID)
require.NoError(ts.T(), err)
require.Len(ts.T(), u.Identities, 1)
require.Equal(ts.T(), u.Identities[0].Provider, c.providerRemaining)

// conditional checks depending on the provider that was unlinked
switch c.provider {
case "phone":
require.Equal(ts.T(), "", u.GetPhone())
require.Nil(ts.T(), u.PhoneConfirmedAt)
case "email":
require.Equal(ts.T(), "", u.GetEmail())
require.Nil(ts.T(), u.EmailConfirmedAt)
}

// user still has a phone / email identity linked so it should not be unconfirmed
require.NotNil(ts.T(), u.ConfirmedAt)
})
}

}
6 changes: 6 additions & 0 deletions internal/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ func (u *User) UpdateUserEmailFromIdentities(tx *storage.Connection) error {
if terr := u.SetEmail(tx, primaryIdentity.GetEmail()); terr != nil {
return terr
}
if primaryIdentity.GetEmail() == "" {
u.EmailConfirmedAt = nil
if terr := tx.UpdateOnly(u, "email_confirmed_at"); terr != nil {
return terr
}
}
return nil
}

Expand Down

0 comments on commit b419f0b

Please sign in to comment.