Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unlink identity bugs #1475

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: remove phone from user if phone identity is unlinked
kangmingtay committed Mar 7, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 505f9d8fc92f80e95770ccf1fcb73ff7ff78e0a1
20 changes: 16 additions & 4 deletions internal/api/identity.go
Original file line number Diff line number Diff line change
@@ -60,11 +60,23 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
if terr := tx.Destroy(identityToBeDeleted); terr != nil {
return internalServerError("Database error deleting identity").WithInternalError(terr)
}
if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil {
if models.IsUniqueConstraintViolatedError(terr) {
return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr)

switch identityToBeDeleted.Provider {
case "phone":
user.PhoneConfirmedAt = nil
if terr := user.SetPhone(tx, ""); terr != nil {
return internalServerError("Database error updating user phone").WithInternalError(terr)
}
if terr := tx.UpdateOnly(user, "phone_confirmed_at"); terr != nil {
return internalServerError("Database error updating user phone").WithInternalError(terr)
}
default:
if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil {
if models.IsUniqueConstraintViolatedError(terr) {
return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr)
}
return internalServerError("Database error updating user email").WithInternalError(terr)
}
return internalServerError("Database error updating user email").WithInternalError(terr)
}
if terr := user.UpdateAppMetaDataProviders(tx); terr != nil {
return internalServerError("Database error updating user providers").WithInternalError(terr)
138 changes: 136 additions & 2 deletions internal/api/identity_test.go
Original file line number Diff line number Diff line change
@@ -2,10 +2,13 @@ package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/gofrs/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/api/provider"
@@ -34,9 +37,10 @@ func (ts *IdentityTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)

// Create user
u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil)
u, err := models.NewUser("", "one@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
require.NoError(ts.T(), u.Confirm(ts.API.db))

// Create identity
i, err := models.NewIdentity(u, "email", map[string]interface{}{
@@ -45,10 +49,31 @@ func (ts *IdentityTestSuite) SetupTest() {
})
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(i))

// Create user with 2 identities
u, err = models.NewUser("123456789", "[email protected]", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
require.NoError(ts.T(), u.Confirm(ts.API.db))
require.NoError(ts.T(), u.ConfirmPhone(ts.API.db))

i, err = models.NewIdentity(u, "email", map[string]interface{}{
"sub": u.ID.String(),
"email": u.GetEmail(),
})
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(i))

i2, err := models.NewIdentity(u, "phone", map[string]interface{}{
"sub": u.ID.String(),
"phone": u.GetPhone(),
})
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(i2))
}

func (ts *IdentityTestSuite) TestLinkIdentityToUser() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
u, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
ctx := withTargetUser(context.Background(), u)

@@ -79,3 +104,112 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() {
require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked"))
require.Nil(ts.T(), u)
}

func (ts *IdentityTestSuite) TestUnlinkIdentityError() {
ts.Config.Security.ManualLinkingEnabled = true
userWithOneIdentity, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

userWithTwoIdentities, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
cases := []struct {
desc string
user *models.User
identityId uuid.UUID
expectedError *HTTPError
}{
{
desc: "User must have at least 1 identity after unlinking",
user: userWithOneIdentity,
identityId: userWithOneIdentity.Identities[0].ID,
expectedError: badRequestError("User must have at least 1 identity after unlinking"),
},
{
desc: "Identity doesn't exist",
user: userWithTwoIdentities,
identityId: uuid.Must(uuid.NewV4()),
expectedError: badRequestError("Identity doesn't exist"),
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, c.user, nil, models.PasswordGrant)
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
w := httptest.NewRecorder()

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expectedError.Code, w.Code)

var data HTTPError
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), c.expectedError.Message, data.Message)
})
}
}

func (ts *IdentityTestSuite) TestUnlinkIdentity() {
ts.Config.Security.ManualLinkingEnabled = true

// we want to test 2 cases here: unlinking a phone identity and email identity from a user
cases := []struct {
desc string
// the provider to be unlinked
provider string
// the remaining provider that should be linked to the user
providerRemaining string
}{
{
desc: "Unlink phone identity successfully",
provider: "phone",
providerRemaining: "email",
},
{
desc: "Unlink email identity successfully",
provider: "email",
providerRemaining: "phone",
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
// teardown and reset the state of the db to prevent running into errors
ts.SetupTest()
u, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider)
require.NoError(ts.T(), err)

token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)

// sanity checks
u, err = models.FindUserByID(ts.API.db, u.ID)
require.NoError(ts.T(), err)
require.Len(ts.T(), u.Identities, 1)
require.Equal(ts.T(), u.Identities[0].Provider, c.providerRemaining)

// conditional checks depending on the provider that was unlinked
switch c.provider {
case "phone":
require.Equal(ts.T(), "", u.GetPhone())
require.Nil(ts.T(), u.PhoneConfirmedAt)
case "email":
require.Equal(ts.T(), "", u.GetEmail())
require.Nil(ts.T(), u.EmailConfirmedAt)
}

// user still has a phone / email identity linked so it should not be unconfirmed
require.NotNil(ts.T(), u.ConfirmedAt)
})
}

}