Skip to content

Commit

Permalink
fix: move creation of flow state into function (supabase#1470)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Follow up to supabase#1446 Moves the creation of the flow state into
`generateFlowState`

---------

Co-authored-by: joel <[email protected]>
  • Loading branch information
J0 and joel authored Mar 5, 2024
1 parent b5566e7 commit 4392a08
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 34 deletions.
5 changes: 1 addition & 4 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,10 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ

flowStateID := ""
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return "", err
}
if err := a.db.Create(flowState); err != nil {
return "", err
}
flowStateID = flowState.ID.String()
}

Expand Down
10 changes: 1 addition & 9 deletions internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,9 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, make(map[string]string))
}
var flowState *models.FlowState

if isPKCEFlow(flowType) {
flowState, err = generateFlowState(models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
return err
}
}
Expand All @@ -138,12 +136,6 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}

mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
Expand Down
5 changes: 4 additions & 1 deletion internal/api/pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,15 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType {
}

// Should only be used with Auth Code of PKCE Flows
func generateFlowState(providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
func generateFlowState(tx *storage.Connection, providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam)
if err != nil {
return nil, err
}
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID)
if err := tx.Create(flowState); err != nil {
return nil, err
}
return flowState, nil

}
9 changes: 1 addition & 8 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}
return internalServerError("Unable to process request").WithInternalError(err)
}
var flowState *models.FlowState
if isPKCEFlow(flowType) {
flowState, err = generateFlowState(models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID))
if err != nil {
if _, err := generateFlowState(db, models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)); err != nil {
return err
}
}
Expand All @@ -70,11 +68,6 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
externalURL := getExternalHost(ctx)
return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
})
Expand Down
5 changes: 1 addition & 4 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,10 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return terr
}
if isPKCEFlow(flowType) {
flowState, terr := generateFlowState(params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
_, terr := generateFlowState(tx, params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if terr != nil {
return terr
}
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
externalURL := getExternalHost(ctx)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil {
Expand Down
5 changes: 1 addition & 4 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,10 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
var flowStateID *uuid.UUID
flowStateID = nil
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return err
}
if err := a.db.Create(flowState); err != nil {
return err
}
flowStateID = &flowState.ID
}

Expand Down
5 changes: 1 addition & 4 deletions internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,11 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
referrer := utilities.GetReferrer(r, config)
flowType := getFlowFromChallenge(params.CodeChallenge)
if isPKCEFlow(flowType) {
flowState, terr := generateFlowState(models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
_, terr := generateFlowState(tx, models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if terr != nil {
return terr
}

if terr := tx.Create(flowState); terr != nil {
return terr
}
}
externalURL := getExternalHost(ctx)
if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil {
Expand Down

0 comments on commit 4392a08

Please sign in to comment.