Skip to content

Commit

Permalink
feat: refactor PKCE FlowState to reduce duplicate code (supabase#1446)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

- Removes `NewFlowStateWithUserID` - it is sufficient to have one method
to create a flow state
- Compresses some of the PKCE checks into a single function

---------

Co-authored-by: joel <[email protected]>
Co-authored-by: Stojan Dimitrovski <[email protected]>
  • Loading branch information
3 people authored Mar 5, 2024
1 parent 16c6528 commit b8d0337
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 56 deletions.
8 changes: 2 additions & 6 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,8 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
flowType := getFlowFromChallenge(codeChallenge)

flowStateID := ""
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return "", err
}
flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth)
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return "", err
}
Expand Down
17 changes: 10 additions & 7 deletions internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,21 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, make(map[string]string))
}
var flowState *models.FlowState

if isPKCEFlow(flowType) {
flowState, err = generateFlowState(models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}

if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.MagicLink.String(), params.CodeChallenge, codeChallengeMethod, models.MagicLink, &user.ID); terr != nil {
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
12 changes: 12 additions & 0 deletions internal/api/pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"regexp"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
)
Expand Down Expand Up @@ -77,3 +78,14 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType {
return models.ImplicitFlow
}
}

// Should only be used with Auth Code of PKCE Flows
func generateFlowState(providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam)
if err != nil {
return nil, err
}
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID)
return flowState, nil

}
15 changes: 9 additions & 6 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,22 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}
return internalServerError("Unable to process request").WithInternalError(err)
}
var flowState *models.FlowState
if isPKCEFlow(flowType) {
flowState, err = generateFlowState(models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID))
if err != nil {
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.Recovery.String(), params.CodeChallenge, codeChallengeMethod, models.Recovery, &(user.ID)); terr != nil {
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
15 changes: 6 additions & 9 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return err
}

var codeChallengeMethod models.CodeChallengeMethod
var err error
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
if codeChallengeMethod, err = models.ParseCodeChallengeMethod(params.CodeChallengeMethod); err != nil {
return err
}
}

var user *models.User
var grantParams models.GrantParams

Expand Down Expand Up @@ -237,8 +230,12 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}); terr != nil {
return terr
}
if ok := isPKCEFlow(flowType); ok {
if terr := models.NewFlowStateWithUserID(tx, params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID); terr != nil {
if isPKCEFlow(flowType) {
flowState, terr := generateFlowState(params.Provider, models.EmailSignup, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if terr != nil {
return terr
}
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
8 changes: 2 additions & 6 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,8 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
flowType := getFlowFromChallenge(params.CodeChallenge)
var flowStateID *uuid.UUID
flowStateID = nil
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return err
}
flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML)
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() {
invalidVerifier := codeVerifier + "123"
codeChallenge := sha256.Sum256([]byte(codeVerifier))
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth)
require.NoError(ts.T(), err)
flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil)
flowState.AuthCode = authCode
require.NoError(ts.T(), ts.API.db.Create(flowState))
cases := []struct {
Expand Down
5 changes: 3 additions & 2 deletions internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
referrer := utilities.GetReferrer(r, config)
flowType := getFlowFromChallenge(params.CodeChallenge)
if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
flowState, terr := generateFlowState(models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.EmailChange.String(), params.CodeChallenge, codeChallengeMethod, models.EmailChange, &user.ID); terr != nil {

if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,8 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload))
codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + c.payload.Type
err := models.NewFlowStateWithUserID(ts.API.db, c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID)
require.NoError(ts.T(), err)
flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID)
require.NoError(ts.T(), ts.API.db.Create(flowState))

requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token)
req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer)
Expand Down
18 changes: 2 additions & 16 deletions internal/models/flow_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,7 @@ func (FlowState) TableName() string {
return tableName
}

func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) {
id := uuid.Must(uuid.NewV4())
authCode := uuid.Must(uuid.NewV4())
flowState := &FlowState{
ID: id,
ProviderType: providerType,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod.String(),
AuthCode: authCode.String(),
AuthenticationMethod: authenticationMethod.String(),
}
return flowState, nil
}

func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error {
func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState {
id := uuid.Must(uuid.NewV4())
authCode := uuid.Must(uuid.NewV4())
flowState := &FlowState{
Expand All @@ -107,7 +93,7 @@ func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge
AuthenticationMethod: authenticationMethod.String(),
UserID: userID,
}
return tx.Create(flowState)
return flowState
}

func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) {
Expand Down

0 comments on commit b8d0337

Please sign in to comment.