Skip to content

Commit

Permalink
feat: refactor PKCE
Browse files Browse the repository at this point in the history
Co-authored-by: Stojan Dimitrovski <[email protected]>
  • Loading branch information
2 people authored and joel committed Feb 21, 2024
1 parent 1ea56b6 commit 8827b71
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 58 deletions.
9 changes: 3 additions & 6 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
flowType := getFlowFromChallenge(codeChallenge)

flowStateID := ""
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return "", err
}
flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth)
if isPKCEFlow(flowType) {
codeChallengeMethodType, err := models.MapCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return "", err
}
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth, nil)
if err := a.db.Create(flowState); err != nil {
return "", err
}
Expand Down
13 changes: 6 additions & 7 deletions internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,17 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, make(map[string]string))
}
flowState, err := generateFlowStateIfPKCE(flowType, models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}

if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.MagicLink.String(), params.CodeChallenge, codeChallengeMethod, models.MagicLink, &user.ID); terr != nil {
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
14 changes: 14 additions & 0 deletions internal/api/pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"regexp"
"time"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
)
Expand Down Expand Up @@ -78,3 +79,16 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType {
return models.ImplicitFlow
}
}

func generateFlowStateIfPKCE(flowType models.FlowType, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
if !isPKCEFlow(flowType) {
return nil, nil
}
codeChallengeMethod, err := models.MapCodeChallengeMethod(codeChallengeMethodParam)
if err != nil {
return nil, err
}
flowState := models.NewFlowState(authenticationMethod.String(), codeChallenge, codeChallengeMethod, authenticationMethod, userID)
return flowState, nil

}
14 changes: 8 additions & 6 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,23 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
return internalServerError("Unable to process request").WithInternalError(err)
}

flowState, err := generateFlowStateIfPKCE(flowType, models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID))
if err != nil {
return err
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.Recovery.String(), params.CodeChallenge, codeChallengeMethod, models.Recovery, &(user.ID)); terr != nil {
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}

externalURL := getExternalHost(ctx)
return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
})
Expand Down
7 changes: 4 additions & 3 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
if codeChallengeMethod, err = models.ParseCodeChallengeMethod(params.CodeChallengeMethod); err != nil {
if codeChallengeMethod, err = models.MapCodeChallengeMethod(params.CodeChallengeMethod); err != nil {
return err
}
}
Expand Down Expand Up @@ -227,8 +227,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}); terr != nil {
return terr
}
if ok := isPKCEFlow(flowType); ok {
if terr := models.NewFlowStateWithUserID(tx, params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID); terr != nil {
if isPKCEFlow(flowType) {
flowState := models.NewFlowState(params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID)
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
15 changes: 6 additions & 9 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,12 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
flowType := getFlowFromChallenge(params.CodeChallenge)
var flowStateID *uuid.UUID
flowStateID = nil
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return err
}
flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML)
if err != nil {
return err
}
flowState, err := generateFlowStateIfPKCE(flowType, models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return err
}

if flowState != nil {
if err := a.db.Create(flowState); err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() {
invalidVerifier := codeVerifier + "123"
codeChallenge := sha256.Sum256([]byte(codeVerifier))
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth)
require.NoError(ts.T(), err)
flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil)
flowState.AuthCode = authCode
require.NoError(ts.T(), ts.API.db.Create(flowState))
cases := []struct {
Expand Down
13 changes: 7 additions & 6 deletions internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,13 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
flowType := getFlowFromChallenge(params.CodeChallenge)
if isPKCEFlow(flowType) {
codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod)
if terr != nil {
return terr
}
if terr := models.NewFlowStateWithUserID(tx, models.EmailChange.String(), params.CodeChallenge, codeChallengeMethod, models.EmailChange, &user.ID); terr != nil {
flowState, err := generateFlowStateIfPKCE(flowType, models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}

if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,8 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload))
codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + c.payload.Type
err := models.NewFlowStateWithUserID(ts.API.db, c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID)
require.NoError(ts.T(), err)
flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID)
require.NoError(ts.T(), ts.API.db.Create(flowState))

requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token)
req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer)
Expand Down
20 changes: 3 additions & 17 deletions internal/models/flow_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (codeChallengeMethod CodeChallengeMethod) String() string {
return ""
}

func ParseCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) {
func MapCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) {
switch strings.ToLower(codeChallengeMethod) {
case "s256":
return SHA256, nil
Expand Down Expand Up @@ -81,21 +81,7 @@ func (FlowState) TableName() string {
return tableName
}

func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) {
id := uuid.Must(uuid.NewV4())
authCode := uuid.Must(uuid.NewV4())
flowState := &FlowState{
ID: id,
ProviderType: providerType,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod.String(),
AuthCode: authCode.String(),
AuthenticationMethod: authenticationMethod.String(),
}
return flowState, nil
}

func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error {
func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState {
id := uuid.Must(uuid.NewV4())
authCode := uuid.Must(uuid.NewV4())
flowState := &FlowState{
Expand All @@ -107,7 +93,7 @@ func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge
AuthenticationMethod: authenticationMethod.String(),
UserID: userID,
}
return tx.Create(flowState)
return flowState
}

func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) {
Expand Down

0 comments on commit 8827b71

Please sign in to comment.