From e1cdf5c4b5c1bf467094f4bdcaa2e42a5cc51c20 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Mon, 4 Mar 2024 11:34:29 +0800 Subject: [PATCH] fix: refactor request params to use generics (#1464) ## What kind of change does this PR introduce? * Introduce a new method `retrieveRequestParams` which makes use of generics to parse a request * This will help to simplify parsing a request from: ```go params := RequestParams{} body, err := getBodyBytes(r) if err != nil { return nil, badRequestError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, ¶ms); err != nil { return nil, badRequestError("Could not decode request params: %v", err) } ``` to ```go params := &Request{} err := retrieveRequestParams(req, params) ``` ## TODO - [x] Add type constraint instead of using `any` --- internal/api/admin.go | 25 +++++++---------------- internal/api/anonymous.go | 4 ++-- internal/api/api.go | 4 ++-- internal/api/helpers.go | 38 +++++++++++++++++++++++++++++++++++ internal/api/invite.go | 12 +++-------- internal/api/mail.go | 12 +++-------- internal/api/mfa.go | 24 ++++++---------------- internal/api/middleware.go | 7 +------ internal/api/otp.go | 16 +++------------ internal/api/recover.go | 12 +++-------- internal/api/resend.go | 12 +++-------- internal/api/signup.go | 18 +++-------------- internal/api/sso.go | 14 ++++--------- internal/api/ssoadmin.go | 23 ++++++--------------- internal/api/token.go | 20 +++++------------- internal/api/token_oidc.go | 11 ++-------- internal/api/token_refresh.go | 11 ++-------- internal/api/user.go | 13 +++--------- internal/api/verify.go | 9 ++------- 19 files changed, 98 insertions(+), 187 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index d6fd17dac7..89f7af9754 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -85,18 +85,12 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex } func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { - params := AdminUserParams{} - - body, err := getBodyBytes(r) - if err != nil { - return nil, badRequestError("Could not read body").WithInternalError(err) + params := &AdminUserParams{} + if err := retrieveRequestParams(r, params); err != nil { + return nil, err } - if err := json.Unmarshal(body, ¶ms); err != nil { - return nil, badRequestError("Could not decode admin user params: %v", err) - } - - return ¶ms, nil + return params, nil } // adminUsers responds with a list of all users in a given audience @@ -565,16 +559,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro user := getUser(ctx) adminUser := getAdminUser(ctx) params := &adminUserUpdateFactorParams{} - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read factor update params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } - err = a.db.Transaction(func(tx *storage.Connection) error { + err := a.db.Transaction(func(tx *storage.Connection) error { if params.FriendlyName != "" { if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil { return terr diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index 11412639e5..5316525a44 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -18,8 +18,8 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { return forbiddenError("Signups not allowed for this instance") } - params, err := retrieveSignupParams(r) - if err != nil { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } params.Aud = aud diff --git a/internal/api/api.go b/internal/api/api.go index eb27c4dce3..73d810fa25 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -149,8 +149,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati DefaultExpirationTTL: time.Hour, }).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) r.Post("/", func(w http.ResponseWriter, r *http.Request) error { - params, err := retrieveSignupParams(r) - if err != nil { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } if params.Email == "" && params.Phone == "" { diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 282ae46d34..ea4102f2e2 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -75,3 +75,41 @@ func isStringInSlice(checkValue string, list []string) bool { func getBodyBytes(req *http.Request) ([]byte, error) { return utilities.GetBodyBytes(req) } + +type RequestParams interface { + AdminUserParams | + CreateSSOProviderParams | + EnrollFactorParams | + GenerateLinkParams | + IdTokenGrantParams | + InviteParams | + OtpParams | + PKCEGrantParams | + PasswordGrantParams | + RecoverParams | + RefreshTokenGrantParams | + ResendConfirmationParams | + SignupParams | + SingleSignOnParams | + SmsParams | + UserUpdateParams | + VerifyFactorParams | + VerifyParams | + adminUserUpdateFactorParams | + struct { + Email string `json:"email"` + Phone string `json:"phone"` + } +} + +// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided +func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { + body, err := getBodyBytes(r) + if err != nil { + return internalServerError("Could not read body into byte slice").WithInternalError(err) + } + if err := json.Unmarshal(body, params); err != nil { + return badRequestError("Could not read request body: %v", err) + } + return nil +} diff --git a/internal/api/invite.go b/internal/api/invite.go index 65a651985e..45d94878c5 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "github.com/fatih/structs" @@ -24,16 +23,11 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { config := a.config adminUser := getAdminUser(ctx) params := &InviteParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Invite params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } + var err error params.Email, err = validateEmail(params.Email) if err != nil { return err diff --git a/internal/api/mail.go b/internal/api/mail.go index 057c451b89..0ab561ab77 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "net/url" "strings" @@ -49,16 +48,11 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { mailer := a.Mailer(ctx) adminUser := getAdminUser(ctx) params := &GenerateLinkParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse JSON: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } + var err error params.Email, err = validateEmail(params.Email) if err != nil { return err diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 6be41e59d5..256d3e3d39 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "encoding/json" "fmt" "net/http" "net/url" @@ -66,15 +65,10 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { config := a.config params := &EnrollFactorParams{} - issuer := "" - body, err := getBodyBytes(r) - if err != nil { - return internalServerError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + if err := retrieveRequestParams(r, params); err != nil { + return err } + issuer := "" if params.FactorType != models.TOTP { return badRequestError("factor_type needs to be totp") @@ -206,16 +200,10 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { config := a.config params := &VerifyFactorParams{} - currentIP := utilities.GetIPAddress(r) - - body, err := getBodyBytes(r) - if err != nil { - return internalServerError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + if err := retrieveRequestParams(r, params); err != nil { + return err } + currentIP := utilities.GetIPAddress(r) if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index cf83f6629b..ab72e32c04 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -95,17 +95,12 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitEmail || shouldRateLimitPhone { if req.Method == "PUT" || req.Method == "POST" { - bodyBytes, err := getBodyBytes(req) - if err != nil { - return c, internalServerError("Error invalid request body").WithInternalError(err) - } - var requestBody struct { Email string `json:"email"` Phone string `json:"phone"` } - if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { + if err := retrieveRequestParams(req, &requestBody); err != nil { return c, badRequestError("Error invalid request body").WithInternalError(err) } diff --git a/internal/api/otp.go b/internal/api/otp.go index 752437bb6e..0e437faa1d 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -68,15 +68,10 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { params.Data = make(map[string]interface{}) } - body, err := getBodyBytes(r) - if err != nil { + if err := retrieveRequestParams(r, params); err != nil { return err } - if err = json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) - } - if err := params.Validate(); err != nil { return err } @@ -115,15 +110,10 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { var err error params := &SmsParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + if err := retrieveRequestParams(r, params); err != nil { + return err } - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read sms otp params: %v", err) - } // For backwards compatibility, we default to SMS if params Channel is not specified if params.Phone != "" && params.Channel == "" { params.Channel = sms_provider.SMSProvider diff --git a/internal/api/recover.go b/internal/api/recover.go index 9a57575650..dcf574d1d4 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "errors" "net/http" @@ -37,14 +36,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) config := a.config params := &RecoverParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } flowType := getFlowFromChallenge(params.CodeChallenge) @@ -53,6 +46,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } var user *models.User + var err error aud := a.requestAud(ctx, r) user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) diff --git a/internal/api/resend.go b/internal/api/resend.go index a49a55e7ed..cb8c4da240 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "errors" "net/http" "time" @@ -68,14 +67,8 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) config := a.config params := &ResendConfirmationParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } if err := params.Validate(config); err != nil { @@ -83,6 +76,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { } var user *models.User + var err error aud := a.requestAud(ctx, r) if params.Email != "" { user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) diff --git a/internal/api/signup.go b/internal/api/signup.go index e2c858f1f7..7093fe3beb 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "fmt" "net/http" "time" @@ -108,18 +107,6 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err return user, nil } -func retrieveSignupParams(r *http.Request) (*SignupParams, error) { - params := &SignupParams{} - body, err := getBodyBytes(r) - if err != nil { - return nil, internalServerError("Could not read body").WithInternalError(err) - } - if err := json.Unmarshal(body, params); err != nil { - return nil, badRequestError("Could not read Signup params: %v", err) - } - return params, nil -} - // Signup is the endpoint for registering a new user func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -130,8 +117,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return forbiddenError("Signups not allowed for this instance") } - params, err := retrieveSignupParams(r) - if err != nil { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -142,6 +129,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } var codeChallengeMethod models.CodeChallengeMethod + var err error flowType := getFlowFromChallenge(params.CodeChallenge) if isPKCEFlow(flowType) { diff --git a/internal/api/sso.go b/internal/api/sso.go index d93ff82dc0..0b4fd89073 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "github.com/crewjam/saml" @@ -41,17 +40,12 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) - body, err := getBodyBytes(r) - if err != nil { - return internalServerError("Unable to read request body").WithInternalError(err) - } - - var params SingleSignOnParams - - if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse request body as JSON").WithInternalError(err) + params := &SingleSignOnParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err } + var err error hasProviderID := false if hasProviderID, err = params.validate(); err != nil { diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 4fdecc0f81..0f966780ee 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "io" "net/http" "net/url" @@ -184,14 +183,9 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er ctx := r.Context() db := a.db.WithContext(ctx) - body, err := getBodyBytes(r) - if err != nil { - return internalServerError("Unable to read request body").WithInternalError(err) - } - - var params CreateSSOProviderParams - if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err } if err := params.validate(false /* <- forUpdate */); err != nil { @@ -264,14 +258,9 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er ctx := r.Context() db := a.db.WithContext(ctx) - body, err := getBodyBytes(r) - if err != nil { - return internalServerError("Unable to read request body").WithInternalError(err) - } - - var params CreateSSOProviderParams - if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err } if err := params.validate(true /* <- forUpdate */); err != nil { diff --git a/internal/api/token.go b/internal/api/token.go index 84970e8cb1..2f6f9e3b22 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "net/http" "net/url" @@ -101,14 +100,8 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri db := a.db.WithContext(ctx) params := &PasswordGrantParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read password grant params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } aud := a.requestAud(ctx, r) @@ -120,6 +113,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri var user *models.User var grantParams models.GrantParams var provider string + var err error grantParams.FillGrantParams(r) @@ -236,13 +230,9 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) grantParams.FillGrantParams(r) params := &PKCEGrantParams{} - body, err := getBodyBytes(r) - if err != nil { - return internalServerError("Could not read body").WithInternalError(err) - } - if err = json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + if err := retrieveRequestParams(r, params); err != nil { + return err } if params.AuthCode == "" || params.CodeVerifier == "" { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index c380856c33..5695cb4894 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -3,7 +3,6 @@ package api import ( "context" "crypto/sha256" - "encoding/json" "fmt" "net/http" @@ -114,14 +113,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R config := a.config params := &IdTokenGrantParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read id token grant params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } if params.IdToken == "" { diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index ebe4b5f2dc..65bbbb031b 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" mathRand "math/rand" "net/http" "time" @@ -26,14 +25,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h config := a.config params := &RefreshTokenGrantParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read refresh token grant params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } if params.RefreshToken == "" { diff --git a/internal/api/user.go b/internal/api/user.go index e31a3eceb6..ddf497f73e 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "net/http" "time" @@ -84,14 +83,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { aud := a.requestAud(ctx, r) params := &UserUpdateParams{} - - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read User Update params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } user := getUser(ctx) @@ -170,7 +163,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } } - err = db.Transaction(func(tx *storage.Connection) error { + err := db.Transaction(func(tx *storage.Connection) error { var terr error if params.Password != nil { var sessionID *uuid.UUID diff --git a/internal/api/verify.go b/internal/api/verify.go index 35f8253eb9..6dd29be05b 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "net/http" "net/url" @@ -107,12 +106,8 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { } return a.verifyGet(w, r, params) case http.MethodPost: - body, err := getBodyBytes(r) - if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse verification params: %v", err) + if err := retrieveRequestParams(r, params); err != nil { + return err } if err := params.Validate(r); err != nil { return err