Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor PKCE FlowState to reduce duplicate code #1446

Merged
merged 6 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
flowType := getFlowFromChallenge(codeChallenge)

flowStateID := ""
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return "", err
}
flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth)
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return "", err
}

J0 marked this conversation as resolved.
Show resolved Hide resolved
if err := a.db.Create(flowState); err != nil {
return "", err
}
Expand Down
17 changes: 10 additions & 7 deletions internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,21 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, make(map[string]string))
}
var flowState *models.FlowState

if isPKCEFlow(flowType) {
J0 marked this conversation as resolved.
Show resolved Hide resolved
flowState, err = generateFlowState(models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}

if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.MagicLink.String(), params.CodeChallenge, codeChallengeMethod, models.MagicLink, &user.ID); terr != nil {
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
12 changes: 12 additions & 0 deletions internal/api/pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"regexp"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
)
Expand Down Expand Up @@ -77,3 +78,14 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType {
return models.ImplicitFlow
}
}

// Should only be used with Auth Code of PKCE Flows
func generateFlowState(providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
J0 marked this conversation as resolved.
Show resolved Hide resolved
codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam)
if err != nil {
return nil, err
}
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID)
return flowState, nil

}
16 changes: 10 additions & 6 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,26 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}
return internalServerError("Unable to process request").WithInternalError(err)
}
var flowState *models.FlowState
if isPKCEFlow(flowType) {
flowState, err = generateFlowState(models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID))
if err != nil {
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.Recovery.String(), params.CodeChallenge, codeChallengeMethod, models.Recovery, &(user.ID)); terr != nil {
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}

J0 marked this conversation as resolved.
Show resolved Hide resolved
externalURL := getExternalHost(ctx)
return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
})
Expand Down
15 changes: 6 additions & 9 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return err
}

var codeChallengeMethod models.CodeChallengeMethod
var err error
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
if codeChallengeMethod, err = models.ParseCodeChallengeMethod(params.CodeChallengeMethod); err != nil {
return err
}
}

var user *models.User
var grantParams models.GrantParams

Expand Down Expand Up @@ -237,8 +230,12 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}); terr != nil {
return terr
}
if ok := isPKCEFlow(flowType); ok {
if terr := models.NewFlowStateWithUserID(tx, params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID); terr != nil {
if isPKCEFlow(flowType) {
flowState, terr := generateFlowState(params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if terr != nil {
return terr
}
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
8 changes: 2 additions & 6 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,8 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
flowType := getFlowFromChallenge(params.CodeChallenge)
var flowStateID *uuid.UUID
flowStateID = nil
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return err
}
flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML)
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() {
invalidVerifier := codeVerifier + "123"
codeChallenge := sha256.Sum256([]byte(codeVerifier))
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth)
require.NoError(ts.T(), err)
flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil)
flowState.AuthCode = authCode
require.NoError(ts.T(), ts.API.db.Create(flowState))
cases := []struct {
Expand Down
5 changes: 3 additions & 2 deletions internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
referrer := utilities.GetReferrer(r, config)
flowType := getFlowFromChallenge(params.CodeChallenge)
if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
flowState, terr := generateFlowState(models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.EmailChange.String(), params.CodeChallenge, codeChallengeMethod, models.EmailChange, &user.ID); terr != nil {

if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,8 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload))
codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + c.payload.Type
err := models.NewFlowStateWithUserID(ts.API.db, c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID)
require.NoError(ts.T(), err)
flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID)
require.NoError(ts.T(), ts.API.db.Create(flowState))

requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token)
req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer)
Expand Down
18 changes: 2 additions & 16 deletions internal/models/flow_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,7 @@ func (FlowState) TableName() string {
return tableName
}

func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) {
id := uuid.Must(uuid.NewV4())
authCode := uuid.Must(uuid.NewV4())
flowState := &FlowState{
ID: id,
ProviderType: providerType,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod.String(),
AuthCode: authCode.String(),
AuthenticationMethod: authenticationMethod.String(),
}
return flowState, nil
}

func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error {
func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState {
id := uuid.Must(uuid.NewV4())
authCode := uuid.Must(uuid.NewV4())
flowState := &FlowState{
Expand All @@ -107,7 +93,7 @@ func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge
AuthenticationMethod: authenticationMethod.String(),
UserID: userID,
}
return tx.Create(flowState)
return flowState
}

func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) {
Expand Down
Loading