diff --git a/internal/api/external.go b/internal/api/external.go index cf9da1a05..facb8ce91 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -80,13 +80,10 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ flowStateID := "" if isPKCEFlow(flowType) { - flowState, err := generateFlowState(providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) + flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) if err != nil { return "", err } - if err := a.db.Create(flowState); err != nil { - return "", err - } flowStateID = flowState.ID.String() } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index 322cf8b8a..ddd3dba37 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -125,11 +125,9 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, make(map[string]string)) } - var flowState *models.FlowState if isPKCEFlow(flowType) { - flowState, err = generateFlowState(models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) - if err != nil { + if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil { return err } } @@ -138,12 +136,6 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { return terr } - if flowState != nil { - if terr := tx.Create(flowState); terr != nil { - return terr - } - } - mailer := a.Mailer(ctx) referrer := utilities.GetReferrer(r, config) externalURL := getExternalHost(ctx) diff --git a/internal/api/pkce.go b/internal/api/pkce.go index 6b3166df8..7c24b6301 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -80,12 +80,15 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType { } // Should only be used with Auth Code of PKCE Flows -func generateFlowState(providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) { +func generateFlowState(tx *storage.Connection, providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) { codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam) if err != nil { return nil, err } flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID) + if err := tx.Create(flowState); err != nil { + return nil, err + } return flowState, nil } diff --git a/internal/api/recover.go b/internal/api/recover.go index 80cac6b89..77e3c068d 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -56,10 +56,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } return internalServerError("Unable to process request").WithInternalError(err) } - var flowState *models.FlowState if isPKCEFlow(flowType) { - flowState, err = generateFlowState(models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)) - if err != nil { + if _, err := generateFlowState(db, models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)); err != nil { return err } } @@ -70,11 +68,6 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } mailer := a.Mailer(ctx) referrer := utilities.GetReferrer(r, config) - if flowState != nil { - if terr := tx.Create(flowState); terr != nil { - return terr - } - } externalURL := getExternalHost(ctx) return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType) }) diff --git a/internal/api/signup.go b/internal/api/signup.go index 9bcac4b41..3d7a19000 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -231,13 +231,10 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return terr } if isPKCEFlow(flowType) { - flowState, terr := generateFlowState(params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + _, terr := generateFlowState(tx, params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) if terr != nil { return terr } - if terr := tx.Create(flowState); terr != nil { - return terr - } } externalURL := getExternalHost(ctx) if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { diff --git a/internal/api/sso.go b/internal/api/sso.go index c1ec8aebb..08ca4c616 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -61,13 +61,10 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { var flowStateID *uuid.UUID flowStateID = nil if isPKCEFlow(flowType) { - flowState, err := generateFlowState(models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil) + flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil) if err != nil { return err } - if err := a.db.Create(flowState); err != nil { - return err - } flowStateID = &flowState.ID } diff --git a/internal/api/user.go b/internal/api/user.go index e2aaedebd..723521c17 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -197,14 +197,11 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { referrer := utilities.GetReferrer(r, config) flowType := getFlowFromChallenge(params.CodeChallenge) if isPKCEFlow(flowType) { - flowState, terr := generateFlowState(models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + _, terr := generateFlowState(tx, models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) if terr != nil { return terr } - if terr := tx.Create(flowState); terr != nil { - return terr - } } externalURL := getExternalHost(ctx) if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil {