From b8d0337922c6712380f6dc74f7eac9fb71b1ae48 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Tue, 5 Mar 2024 16:56:34 +0800 Subject: [PATCH] feat: refactor PKCE FlowState to reduce duplicate code (#1446) ## What kind of change does this PR introduce? - Removes `NewFlowStateWithUserID` - it is sufficient to have one method to create a flow state - Compresses some of the PKCE checks into a single function --------- Co-authored-by: joel Co-authored-by: Stojan Dimitrovski --- internal/api/external.go | 8 ++------ internal/api/magic_link.go | 17 ++++++++++------- internal/api/pkce.go | 12 ++++++++++++ internal/api/recover.go | 15 +++++++++------ internal/api/signup.go | 15 ++++++--------- internal/api/sso.go | 8 ++------ internal/api/token_test.go | 3 +-- internal/api/user.go | 5 +++-- internal/api/verify_test.go | 4 ++-- internal/models/flow_state.go | 18 ++---------------- 10 files changed, 49 insertions(+), 56 deletions(-) diff --git a/internal/api/external.go b/internal/api/external.go index 6982f2c6e..cf9da1a05 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -79,12 +79,8 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ flowType := getFlowFromChallenge(codeChallenge) flowStateID := "" - if flowType == models.PKCEFlow { - codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) - if err != nil { - return "", err - } - flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) if err != nil { return "", err } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index e1b12caaf..322cf8b8a 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -125,18 +125,21 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, make(map[string]string)) } + var flowState *models.FlowState + + if isPKCEFlow(flowType) { + flowState, err = generateFlowState(models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if err != nil { + return err + } + } err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { return terr } - - if isPKCEFlow(flowType) { - codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod) - if terr != nil { - return terr - } - if terr := models.NewFlowStateWithUserID(tx, models.MagicLink.String(), params.CodeChallenge, codeChallengeMethod, models.MagicLink, &user.ID); terr != nil { + if flowState != nil { + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index c5ed5e54a..6b3166df8 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -3,6 +3,7 @@ package api import ( "regexp" + "github.com/gofrs/uuid" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" ) @@ -77,3 +78,14 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType { return models.ImplicitFlow } } + +// Should only be used with Auth Code of PKCE Flows +func generateFlowState(providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) { + codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam) + if err != nil { + return nil, err + } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID) + return flowState, nil + +} diff --git a/internal/api/recover.go b/internal/api/recover.go index dcf574d1d..80cac6b89 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -56,6 +56,13 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } return internalServerError("Unable to process request").WithInternalError(err) } + var flowState *models.FlowState + if isPKCEFlow(flowType) { + flowState, err = generateFlowState(models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)) + if err != nil { + return err + } + } err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { @@ -63,12 +70,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } mailer := a.Mailer(ctx) referrer := utilities.GetReferrer(r, config) - if isPKCEFlow(flowType) { - codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod) - if terr != nil { - return terr - } - if terr := models.NewFlowStateWithUserID(tx, models.Recovery.String(), params.CodeChallenge, codeChallengeMethod, models.Recovery, &(user.ID)); terr != nil { + if flowState != nil { + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/signup.go b/internal/api/signup.go index 7093fe3be..9bcac4b41 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -128,16 +128,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return err } - var codeChallengeMethod models.CodeChallengeMethod var err error flowType := getFlowFromChallenge(params.CodeChallenge) - if isPKCEFlow(flowType) { - if codeChallengeMethod, err = models.ParseCodeChallengeMethod(params.CodeChallengeMethod); err != nil { - return err - } - } - var user *models.User var grantParams models.GrantParams @@ -237,8 +230,12 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }); terr != nil { return terr } - if ok := isPKCEFlow(flowType); ok { - if terr := models.NewFlowStateWithUserID(tx, params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID); terr != nil { + if isPKCEFlow(flowType) { + flowState, terr := generateFlowState(params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if terr != nil { + return terr + } + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/sso.go b/internal/api/sso.go index 0b4fd8907..c1ec8aebb 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -60,12 +60,8 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { flowType := getFlowFromChallenge(params.CodeChallenge) var flowStateID *uuid.UUID flowStateID = nil - if flowType == models.PKCEFlow { - codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) - if err != nil { - return err - } - flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML) + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil) if err != nil { return err } diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 0c8cc2377..b12a79a8e 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { invalidVerifier := codeVerifier + "123" codeChallenge := sha256.Sum256([]byte(codeVerifier)) challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) - flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth) - require.NoError(ts.T(), err) + flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil) flowState.AuthCode = authCode require.NoError(ts.T(), ts.API.db.Create(flowState)) cases := []struct { diff --git a/internal/api/user.go b/internal/api/user.go index ddf497f73..e2aaedebd 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -197,11 +197,12 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { referrer := utilities.GetReferrer(r, config) flowType := getFlowFromChallenge(params.CodeChallenge) if isPKCEFlow(flowType) { - codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod) + flowState, terr := generateFlowState(models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) if terr != nil { return terr } - if terr := models.NewFlowStateWithUserID(tx, models.EmailChange.String(), params.CodeChallenge, codeChallengeMethod, models.EmailChange, &user.ID); terr != nil { + + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 51118f35e..c782446fb 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -661,8 +661,8 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + c.payload.Type - err := models.NewFlowStateWithUserID(ts.API.db, c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) - require.NoError(ts.T(), err) + flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) + require.NoError(ts.T(), ts.API.db.Create(flowState)) requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token) req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer) diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 6aced0b59..04e880ec2 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -81,21 +81,7 @@ func (FlowState) TableName() string { return tableName } -func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) { - id := uuid.Must(uuid.NewV4()) - authCode := uuid.Must(uuid.NewV4()) - flowState := &FlowState{ - ID: id, - ProviderType: providerType, - CodeChallenge: codeChallenge, - CodeChallengeMethod: codeChallengeMethod.String(), - AuthCode: authCode.String(), - AuthenticationMethod: authenticationMethod.String(), - } - return flowState, nil -} - -func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error { +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState { id := uuid.Must(uuid.NewV4()) authCode := uuid.Must(uuid.NewV4()) flowState := &FlowState{ @@ -107,7 +93,7 @@ func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge AuthenticationMethod: authenticationMethod.String(), UserID: userID, } - return tx.Create(flowState) + return flowState } func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) {