From e4beea1cdb80544b0581f1882696a698fdf64938 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Wed, 13 Mar 2024 12:56:37 +0100 Subject: [PATCH] feat: add error codes (#1377) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds proper error codes with API versioning. From now on, all responses that end in a 4XX HTTP status code will include a textual description of the error that occurred. Error responses on API versions before `2024-01-01` have this schema: ```json { "code": "", "msg": "", "error_code": "" } ``` Error responses on API version on or after `2024-01-01` have this schema: ```json { "code": "", "message": "" } ``` API versions are controlled by submitting an `X-Supabase-Api-Version` header to the request. A missing or invalid value assumes the "initial" API version as used before the introduction of API version `2024-01-01`. Error code contract for API version `2024-01-01`: 1. Error codes will not be renamed. You can safely rely on them. 2. HTTP status codes are _mostly_ fixed, but you should not rely on them except the class 4XX vs 5XX. 3. Error messages are a _developer aid._ They may change across deployed version. You should not rely on them, but if you want you can show them to your users. Error translations should be based on the error code! Of the 4XX HTTP status code class, only these codes are allowed to be used in API version `2024-01-01` according to these rules. The purpose of this is to keep proper HTTP semantics. The tuple `(http_status_code, error_code)` shouldn't be used by clients!
HTTP Status CodeWhen to use?Primary Fault At
400 Bad RequestInputs (body, headers) and their contents are not valid as a whole, or parts of them. Example: bad JSON, bad JSON object, using two mutually exclusive JSON fields, missing required fields, wrong encoding…

If the answer to this question is
yes then you should probably use 400: Is there some technical thing the developer should do to get a different status code?
Developer.

MUST NEVER OCCUR WHEN USING AN OFFICIAL SUPABASE LIBRARY.

Why?
- Library should not send invalid requests.
- If occurring, means: improper types, no client-side validation.
401 UnauthorizedYou must use this code if security headers or inputs are missing, and are not valid to some extent. Example: missing JWT, missing CAPTCHA token, missing important query params that serve to authenticate the caller.

You may use 400 instead if the security headers or inputs are provided and relatively valid (valid JWT signature, bad claims) instead, though prefer 401 over 400 unless it aids in DX.

Do not use this code to signal that the user does not have sufficient application privileges.

If the answer to this question is
yes then you should use 401: Are the credentials the user/client sending missing or invalid in form, structure, encoding?
Developer.

MUST NEVER OCCUR WHEN USING AN OFFICIAL SUPABASE LIBRARY.

Why?
- Library should never send improper requests (missing authorization headers for features that require authorization).
- If occurring means: broken logic, improper types, no client-side validation.
403 ForbiddenDo not use this code for bad JWT format, missing headers or other validation on security sensitive payloads. Return 400 on those.

Once security payloads have been validated in structure, only return this error if the user/client can be authenticated properly but they do not possess the proper authorization to access the resource.

If the answer to this question is
yes then you should use 403: Should the user/client be someone else to get a different status code?

In some cases you should prefer 200 responses with empty bodies, akin to RLS behavior.
User.

Developer is at fault for not hiding the feature sufficiently.

Can rarely occur when using Supabase libraries, and in such cases it means docs / explanation problems.
404 Not FoundDo not use this for missing objects in the database. Prefer using 422 instead.

Use only if the URL cannot be fully validated, resulting in a resource that cannot be properly identified. If there’s no variables in the path, this code must not be used.

Good:
- /users/<not-uuid>
- /users/<uuid> (but such a user does not exist)

Bad:
- /token?grant_type=password (no variables in URL)
- /sessions (no variables in URL)

For cases where a feature is disabled on a server consider using 501 or 422.

For requests that “look up” entities consider using 200 with an empty/null response body or 204 No Content instead.
Developer.

MUST NEVER OCCUR WHEN USING AN OFFICIAL SUPABASE LIBRARY.

Why?
- Library should never send improperly formatted URLs, or encode data in URLs that it knows to be invalid.
- Ideally library should not take in freeform input about entities, and should use some “proof” that the entity exists. Example: calling methods on objects returned by a list/find-by-id method.
- In some situations it’s inevitable (like in admin APIs).
422 Unprocessable EntityDo not use for bad inputs!

Once all inputs to a request have been validated to the fullest extent possible (e.g. OK to validate an email address format, but not necessary to validate that there’s a SMTP server listening), use this status code to signal errors with the processing logic. This includes all logic dependent on database state (user exists, or doesn’t). All third-party expected errors (like calling into a third service) should end with 422.

If the answer to this question is
yes then you should use 422: Is there something different that the user should do to get a different status code?
User.

Developer is at fault for using the feature in an improper part of the flow.

Can rarely occur when using Supabase libraries, and in such cases it means docs / explanation problems.
429 Too Many RequestsOnly use this for rate-limiting or other cases that limit the number of requests. Third-party rate-limits should be propagated with this error.User.

Developer should have implemented a better UX to prevent reaching this error for legitimate users. (Disabling buttons, adding timeout UI elements…)
500 Internal Server ErrorUse this for any unexpected errors when processing a request. Default to this code if you can’t find a 4XX error code for it.Supabase. Developer in some cases (such as when changing database contents).

The cause of this error should not be situations arising from official Supabase libraries (such as sending inputs that crash the server).

501 Not ImplementedA feature is disabled, not configured (properly), blocked or otherwise unavailable.Developer, for not enabling or configuring features properly.
--------- Co-authored-by: joel --- .../workflows/conventional-commits-lint.js | 7 +- internal/api/admin.go | 29 ++- internal/api/anonymous.go | 2 +- internal/api/api.go | 2 +- internal/api/apiversions.go | 35 +++ internal/api/apiversions_test.go | 29 +++ internal/api/audit.go | 4 +- internal/api/auth.go | 22 +- internal/api/auth_test.go | 4 +- internal/api/errorcodes.go | 77 ++++++ internal/api/errors.go | 245 +++++++++--------- internal/api/errors_test.go | 64 +++++ internal/api/external.go | 53 ++-- internal/api/external_figma_test.go | 2 +- internal/api/external_fly_test.go | 2 +- internal/api/external_github_test.go | 4 +- internal/api/external_kakao_test.go | 4 +- internal/api/external_oauth.go | 11 +- internal/api/helpers.go | 2 +- internal/api/helpers_test.go | 8 +- internal/api/hooks.go | 16 +- internal/api/identity.go | 31 ++- internal/api/identity_test.go | 8 +- internal/api/invite.go | 2 +- internal/api/invite_test.go | 2 +- internal/api/logout.go | 2 +- internal/api/magic_link.go | 17 +- internal/api/mail.go | 116 ++++++--- internal/api/mfa.go | 24 +- internal/api/mfa_test.go | 11 +- internal/api/middleware.go | 17 +- internal/api/middleware_test.go | 8 +- internal/api/otp.go | 24 +- internal/api/otp_test.go | 33 ++- internal/api/phone.go | 2 +- internal/api/phone_test.go | 23 +- internal/api/pkce.go | 8 +- internal/api/reauthenticate.go | 23 +- internal/api/recover.go | 4 +- internal/api/resend.go | 21 +- internal/api/router.go | 4 +- internal/api/samlacs.go | 24 +- internal/api/signup.go | 45 ++-- internal/api/sso.go | 8 +- internal/api/sso_test.go | 2 +- internal/api/ssoadmin.go | 36 +-- internal/api/token.go | 20 +- internal/api/token_oidc.go | 6 +- internal/api/token_test.go | 4 +- internal/api/user.go | 25 +- internal/api/user_test.go | 2 +- internal/api/verify.go | 46 ++-- internal/api/verify_test.go | 22 +- 53 files changed, 787 insertions(+), 455 deletions(-) create mode 100644 internal/api/apiversions.go create mode 100644 internal/api/apiversions_test.go create mode 100644 internal/api/errorcodes.go create mode 100644 internal/api/errors_test.go diff --git a/.github/workflows/conventional-commits-lint.js b/.github/workflows/conventional-commits-lint.js index a3815e5e7..96d6c9828 100644 --- a/.github/workflows/conventional-commits-lint.js +++ b/.github/workflows/conventional-commits-lint.js @@ -46,7 +46,12 @@ let failed = false; validate.forEach((payload) => { if (payload.title) { - const { groups } = payload.title.match(TITLE_PATTERN); + const match = payload.title.match(TITLE_PATTERN); + if (!match) { + return + } + + const { groups } = match if (groups) { if (groups.breaking) { diff --git a/internal/api/admin.go b/internal/api/admin.go index 89f7af975..f7acf3b45 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -50,7 +50,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, userID, err := uuid.FromString(chi.URLParam(r, "user_id")) if err != nil { - return nil, badRequestError("user_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID") } observability.LogEntrySetField(r, "user_id", userID) @@ -58,7 +58,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, u, err := models.FindUserByID(db, userID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("User not found") + return nil, notFoundError(ErrorCodeUserNotFound, "User not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -69,7 +69,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { - return nil, badRequestError("factor_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") } observability.LogEntrySetField(r, "factor_id", factorID) @@ -77,7 +77,7 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex f, err := models.FindFactorByFactorID(a.db, factorID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Factor not found") + return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } @@ -101,12 +101,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) } sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { - return badRequestError("Bad Sort Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) } filter := r.URL.Query().Get("filter") @@ -160,7 +160,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -308,7 +308,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if params.Email == "" && params.Phone == "" { - return unprocessableEntityError("Cannot create a user without either an email or phone") + return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") } var providers []string @@ -320,7 +320,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if user != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } providers = append(providers, "email") } @@ -333,7 +333,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError("Phone number already registered by another user") + return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user") } providers = append(providers, "phone") } @@ -429,7 +429,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -460,11 +460,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { params := &adminUserDeleteParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if len(body) > 0 { if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err) } } else { params.ShouldSoftDelete = false @@ -559,6 +559,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro user := getUser(ctx) adminUser := getAdminUser(ctx) params := &adminUserUpdateFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -571,7 +572,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro } if params.FactorType != "" { if params.FactorType != models.TOTP { - return badRequestError("Factor Type not valid") + return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid") } if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil { return terr diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index 5316525a4..4024d5947 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -15,7 +15,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { aud := a.requestAud(ctx, r) if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} diff --git a/internal/api/api.go b/internal/api/api.go index 73d810fa2..edac716a6 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -155,7 +155,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati } if params.Email == "" && params.Phone == "" { if !api.config.External.AnonymousUsers.Enabled { - return unprocessableEntityError("Anonymous sign-ins are disabled") + return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") } if _, err := api.limitHandler(limiter)(w, r); err != nil { return err diff --git a/internal/api/apiversions.go b/internal/api/apiversions.go new file mode 100644 index 000000000..b5394a5fc --- /dev/null +++ b/internal/api/apiversions.go @@ -0,0 +1,35 @@ +package api + +import ( + "time" +) + +const APIVersionHeaderName = "X-Supabase-Api-Version" + +type APIVersion = time.Time + +var ( + APIVersionInitial = time.Time{} + APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) +) + +func DetermineClosestAPIVersion(date string) (APIVersion, error) { + if date == "" { + return APIVersionInitial, nil + } + + parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC) + if err != nil { + return APIVersionInitial, err + } + + if parsed.Compare(APIVersion20240101) >= 0 { + return APIVersion20240101, nil + } + + return APIVersionInitial, nil +} + +func FormatAPIVersion(apiVersion APIVersion) string { + return apiVersion.Format("2006-01-02") +} diff --git a/internal/api/apiversions_test.go b/internal/api/apiversions_test.go new file mode 100644 index 000000000..0a9622132 --- /dev/null +++ b/internal/api/apiversions_test.go @@ -0,0 +1,29 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetermineClosestAPIVersion(t *testing.T) { + version, err := DetermineClosestAPIVersion("") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("Not a date") + require.Error(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2023-12-31") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2024-01-01") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) + + version, err = DetermineClosestAPIVersion("2024-01-02") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) +} diff --git a/internal/api/audit.go b/internal/api/audit.go index 2cb99c6e7..351a7d2cd 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } var col []string @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError("Invalid query scope: %s", q) + return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q) } qval = qparts[1] } diff --git a/internal/api/auth.go b/internal/api/auth.go index dbd0278bd..3e69d8c6c 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -39,7 +39,7 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte ctx := r.Context() claims := getClaims(ctx) if claims.IsAnonymous { - return nil, forbiddenError("Anonymous user not allowed to perform these actions") + return nil, forbiddenError(ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions") } return ctx, nil } @@ -49,7 +49,7 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex claims := getClaims(ctx) if claims == nil { fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") - return nil, unauthorizedError("Invalid token") + return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token") } adminRoles := a.config.JWT.AdminRoles @@ -60,14 +60,14 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex } fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, unauthorizedError("User not allowed") + return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed") } func (a *API) extractBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") matches := bearerRegexp.FindStringSubmatch(authHeader) if len(matches) != 2 { - return "", unauthorizedError("This endpoint requires a Bearer token") + return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") } return matches[1], nil @@ -82,7 +82,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return []byte(config.JWT.Secret), nil }) if err != nil { - return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err) + return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) } return withToken(ctx, token), nil @@ -93,23 +93,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro claims := getClaims(ctx) if claims == nil { - return ctx, unauthorizedError("invalid token: missing claims") + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims") } if claims.Subject == "" { - return nil, unauthorizedError("invalid claim: missing sub claim") + return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim") } var user *models.User if claims.Subject != "" { userId, err := uuid.FromString(claims.Subject) if err != nil { - return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err) + return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) } user, err = models.FindUserByID(db, userId) if err != nil { if models.IsNotFoundError(err) { - return ctx, notFoundError(err.Error()) + return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") } return ctx, err } @@ -120,11 +120,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { sessionId, err := uuid.FromString(claims.SessionId) if err != nil { - return ctx, err + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) } session, err = models.FindSessionByID(db, sessionId, false) if err != nil && !models.IsNotFoundError(err) { - return ctx, err + return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist") } ctx = withSession(ctx, session) } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 6c95a0bd9..f404e1cb7 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -96,7 +96,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: unauthorizedError("invalid claim: missing sub claim"), + ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"), ExpectedUser: nil, }, { @@ -118,7 +118,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"), + ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), ExpectedUser: nil, }, { diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go new file mode 100644 index 000000000..45dec0dd7 --- /dev/null +++ b/internal/api/errorcodes.go @@ -0,0 +1,77 @@ +package api + +type ErrorCode = string + +const ( + // ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly. + ErrorCodeUnknown ErrorCode = "unknown" + + // ErrorCodeUnexpectedFailure signals an unexpected failure such as a 500 Internal Server Error. + ErrorCodeUnexpectedFailure ErrorCode = "unexpected_failure" + + ErrorCodeValidationFailed ErrorCode = "validation_failed" + ErrorCodeBadJSON ErrorCode = "bad_json" + ErrorCodeEmailExists ErrorCode = "email_exists" + ErrorCodePhoneExists ErrorCode = "phone_exists" + ErrorCodeBadJWT ErrorCode = "bad_jwt" + ErrorCodeNotAdmin ErrorCode = "not_admin" + ErrorCodeNoAuthorization ErrorCode = "no_authorization" + ErrorCodeUserNotFound ErrorCode = "user_not_found" + ErrorCodeSessionNotFound ErrorCode = "session_not_found" + ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" + ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeSignupDisabled ErrorCode = "signup_disabled" + ErrorCodeUserBanned ErrorCode = "user_banned" + ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" + ErrorCodeInviteNotFound ErrorCode = "invite_not_found" + ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state" + ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback" + ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported" + ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience" + ErrorCodeSingleIdentityNotDeletable ErrorCode = "single_identity_not_deletable" + ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable" + ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists" + ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled" + ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled" + ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors" + ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict" + ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found" + ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch" + ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired" + ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed" + ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected" + ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal" + ErrorCodeCaptchaFailed ErrorCode = "captcha_failed" + ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled" + ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled" + ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed" + ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed" + ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed" + ErrorCodeReauthNonceMissing ErrorCode = "reauth_nonce_missing" + ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found" + ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired" + ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found" + ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id" + ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email" + ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" + ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" + ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" + ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" + ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" + ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch" + ErrorCodeConflict ErrorCode = "conflict" + ErrorCodeProviderDisabled ErrorCode = "provider_disabled" + ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed" + ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed" + ErrorCodeSamePassword ErrorCode = "same_password" + ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid" + ErrorCodeOTPExpired ErrorCode = "otp_expired" + ErrorCodeOTPDisabled ErrorCode = "otp_disabled" + ErrorCodeIdentityNotFound ErrorCode = "identity_not_found" + ErrorCodeWeakPassword ErrorCode = "weak_password" + ErrorCodeOverRequestRateLimit ErrorCode = "over_request_rate_limit" + ErrorCodeOverEmailSendRateLimit ErrorCode = "over_email_send_rate_limit" + ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" + ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" + ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" +) diff --git a/internal/api/errors.go b/internal/api/errors.go index 56f404e3c..cc6ba877b 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -8,7 +8,6 @@ import ( "runtime/debug" "github.com/pkg/errors" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/utilities" ) @@ -65,65 +64,43 @@ func (e *OAuthError) Cause() error { return e } -func invalidSignupError(config *conf.GlobalConfiguration) *HTTPError { - var msg string - if config.External.Email.Enabled && config.External.Phone.Enabled { - msg = "To signup, please provide your email or phone number" - } else if config.External.Email.Enabled { - msg = "To signup, please provide your email" - } else if config.External.Phone.Enabled { - msg = "To signup, please provide your phone number" - } else { - // 3rd party OAuth signups - msg = "To signup, please provide required fields" - } - return unprocessableEntityError(msg) -} - func oauthError(err string, description string) *OAuthError { return &OAuthError{Err: err, Description: description} } -func badRequestError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusBadRequest, fmtString, args...) +func badRequestError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusBadRequest, errorCode, fmtString, args...) } func internalServerError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusInternalServerError, fmtString, args...) -} - -func notFoundError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusNotFound, fmtString, args...) + return httpError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...) } -func expiredTokenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) +func notFoundError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusNotFound, errorCode, fmtString, args...) } -func unauthorizedError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) +func forbiddenError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusForbidden, errorCode, fmtString, args...) } -func forbiddenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusForbidden, fmtString, args...) +func unprocessableEntityError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) } -func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnprocessableEntity, fmtString, args...) -} - -func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusTooManyRequests, fmtString, args...) +func tooManyRequestsError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusTooManyRequests, errorCode, fmtString, args...) } func conflictError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusConflict, fmtString, args...) + return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) } // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { - Code int `json:"code"` - Message string `json:"msg"` + HTTPStatus int `json:"code"` // do not rename the JSON tags! + ErrorCode string `json:"error_code,omitempty"` // do not rename the JSON tags! + Message string `json:"msg"` // do not rename the JSON tags! InternalError error `json:"-"` InternalMessage string `json:"-"` ErrorID string `json:"error_id,omitempty"` @@ -133,7 +110,7 @@ func (e *HTTPError) Error() string { if e.InternalMessage != "" { return e.InternalMessage } - return fmt.Sprintf("%d: %s", e.Code, e.Message) + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message) } func (e *HTTPError) Is(target error) bool { @@ -160,50 +137,12 @@ func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) * return e } -func httpError(code int, fmtString string, args ...interface{}) *HTTPError { +func httpError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return &HTTPError{ - Code: code, - Message: fmt.Sprintf(fmtString, args...), - } -} - -// OTPError is a custom error struct for phone auth errors -type OTPError struct { - Err string `json:"error"` - Description string `json:"error_description,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` -} - -func (e *OTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%s: %s", e.Err, e.Description) -} - -// WithInternalError adds internal error information to the error -func (e *OTPError) WithInternalError(err error) *OTPError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// Cause returns the root cause error -func (e *OTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError + HTTPStatus: httpStatus, + ErrorCode: errorCode, + Message: fmt.Sprintf(fmtString, args...), } - return e -} - -func otpError(err string, description string) *OTPError { - return &OTPError{Err: err, Description: description} } // Recoverer is a middleware that recovers from panics, logs the panic (and a @@ -222,10 +161,10 @@ func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) } se := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), } - handleError(se, w, r) + HandleResponseError(se, w, r) } }() @@ -237,28 +176,61 @@ type ErrorCause interface { Cause() error } -func handleError(err error, w http.ResponseWriter, r *http.Request) { +type HTTPErrorResponse20240101 struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { log := observability.GetLogEntry(r) errorID := getRequestID(r.Context()) + + apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) + if averr != nil { + log.WithError(averr).Warn("Invalid version passed to " + APIVersionHeaderName + " header, defaulting to initial version") + } else if apiVersion != APIVersionInitial { + // Echo back the determined API version from the request + w.Header().Set(APIVersionHeaderName, FormatAPIVersion(apiVersion)) + } + switch e := err.(type) { case *WeakPasswordError: - var output struct { - HTTPError - Payload struct { - Reasons []string `json:"reasons,omitempty"` - } `json:"weak_password,omitempty"` - } + if apiVersion.Compare(APIVersion20240101) >= 0 { + var output struct { + HTTPErrorResponse20240101 + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } - output.Code = http.StatusUnprocessableEntity - output.Message = e.Message - output.Payload.Reasons = e.Reasons + output.Code = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, output.Code, output); jsonErr != nil { - handleError(jsonErr, w, r) + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + + } else { + var output struct { + HTTPError + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.HTTPStatus = http.StatusUnprocessableEntity + output.ErrorCode = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } case *HTTPError: - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -266,35 +238,76 @@ func handleError(err error, w http.ResponseWriter, r *http.Request) { log.WithError(e.Cause()).Info(e.Error()) } - // Provide better error messages for certain user-triggered Postgres errors. - if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { - if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { - handleError(jsonErr, w, r) + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: e.ErrorCode, + Message: e.Message, + } + + if resp.Code == "" { + if e.HTTPStatus == http.StatusInternalServerError { + resp.Code = ErrorCodeUnexpectedFailure + } else { + resp.Code = ErrorCodeUnknown + } + } + + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + if e.ErrorCode == "" { + if e.HTTPStatus == http.StatusInternalServerError { + e.ErrorCode = ErrorCodeUnexpectedFailure + } else { + e.ErrorCode = ErrorCodeUnknown + } } - return - } - if jsonErr := sendJSON(w, e.Code, e); jsonErr != nil { - handleError(jsonErr, w, r) + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + return + } + + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } + case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case *OTPError: - log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) + HandleResponseError(jsonErr, w, r) } + case ErrorCause: - handleError(e.Cause(), w, r) + HandleResponseError(e.Cause(), w, r) + default: log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) - // hide real error details from response to prevent info leaks - w.WriteHeader(http.StatusInternalServerError) - if _, writeErr := w.Write([]byte(`{"code":500,"msg":"Internal server error","error_id":"` + errorID + `"}`)); writeErr != nil { - log.WithError(writeErr).Error("Error writing generic error message") + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + httpError := HTTPError{ + HTTPStatus: http.StatusInternalServerError, + ErrorCode: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } } } diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go new file mode 100644 index 000000000..fc6135205 --- /dev/null +++ b/internal/api/errors_test.go @@ -0,0 +1,64 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandleResponseErrorWithHTTPError(t *testing.T) { + examples := []struct { + HTTPError *HTTPError + APIVersion string + ExpectedBody string + }{ + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2023-12-31", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeBadJSON + "\",\"message\":\"Unable to parse JSON\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + Message: "Uncoded failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnknown + "\",\"message\":\"Uncoded failure\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: "Unexpected failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnexpectedFailure + "\",\"message\":\"Unexpected failure\"}", + }, + } + + for _, example := range examples { + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + if example.APIVersion != "" { + req.Header.Set(APIVersionHeaderName, example.APIVersion) + } + + HandleResponseError(example.HTTPError, rec, req) + + require.Equal(t, example.HTTPError.HTTPStatus, rec.Code) + require.Equal(t, example.ExpectedBody, rec.Body.String()) + } +} diff --git a/internal/api/external.go b/internal/api/external.go index 177d20059..8fa27f4ae 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -56,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError(ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -64,7 +64,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return "", notFoundError(userErr.Error()) + return "", notFoundError(ErrorCodeUserNotFound, "User identified by token not found") } return "", internalServerError("Database error finding user").WithInternalError(userErr) } @@ -127,6 +127,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ } authURL := p.AuthCodeURL(tokenString, authUrlParams...) + return authURL, nil } @@ -196,9 +197,12 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // if there's a non-empty FlowStateID we perform PKCE Flow if flowStateID := getFlowStateID(ctx); flowStateID != "" { flowState, err = models.FindFlowStateByID(a.db, flowStateID) - if err != nil { - return err + if models.IsNotFoundError(err) { + return unprocessableEntityError(ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) } + } var user *models.User @@ -300,7 +304,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. case models.CreateAccount: if config.DisableSignup { - return nil, forbiddenError("Signups not allowed for this instance") + return nil, unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{ @@ -347,14 +351,14 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } case models.MultipleAccounts: - return nil, internalServerError(fmt.Sprintf("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)) + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) default: - return nil, internalServerError(fmt.Sprintf("Unknown automatic linking decision: %v", decision.Decision)) + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) } if user.IsBanned() { - return nil, unauthorizedError("User is unauthorized") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } if !user.IsConfirmed() { @@ -383,7 +387,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. externalURL := getExternalHost(ctx) if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every minute") } return nil, internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -391,9 +395,9 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { @@ -411,7 +415,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p user, err := models.FindUserByConfirmationToken(tx, inviteToken) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) + return nil, notFoundError(ErrorCodeInviteNotFound, "Invite not found") } return nil, internalServerError("Database error finding user").WithInternalError(err) } @@ -427,7 +431,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p } if emailData == nil { - return nil, badRequestError("Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + return nil, badRequestError(ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) } var identityData map[string]interface{} @@ -480,8 +484,11 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { return []byte(config.JWT.Secret), nil }) - if err != nil || claims.Provider == "" { - return nil, badRequestError("OAuth state is invalid: %v", err) + if err != nil { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + } + if claims.Provider == "" { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -495,12 +502,12 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { - return nil, badRequestError("invalid target user id") + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") } u, err := models.FindUserByID(a.db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Linking target user not found") + return nil, unprocessableEntityError(ErrorCodeUserNotFound, "Linking target user not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -591,12 +598,18 @@ func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http. func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values { switch e := err.(type) { case *HTTPError: - if str, ok := oauthErrorMap[e.Code]; ok { + if e.ErrorCode == ErrorCodeSignupDisabled { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeUserBanned { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeProviderEmailNeedsVerification { + q.Set("error", "access_denied") + } else if str, ok := oauthErrorMap[e.HTTPStatus]; ok { q.Set("error", str) } else { q.Set("error", "server_error") } - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -604,7 +617,7 @@ func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q ur log.WithError(e.Cause()).Info(e.Error()) } q.Set("error_description", e.Message) - q.Set("error_code", strconv.Itoa(e.Code)) + q.Set("error_code", strconv.Itoa(e.HTTPStatus)) case *OAuthError: q.Set("error", e.Err) q.Set("error_description", e.Description) diff --git a/internal/api/external_figma_test.go b/internal/api/external_figma_test.go index 56d2f478d..bd7a8c29c 100644 --- a/internal/api/external_figma_test.go +++ b/internal/api/external_figma_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFigmaErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "figma", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_fly_test.go b/internal/api/external_fly_test.go index c469f2900..3c33c53e2 100644 --- a/internal/api/external_fly_test.go +++ b/internal/api/external_fly_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFlyErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "fly", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_github_test.go b/internal/api/external_github_test.go index b3ad58440..f6f4334d7 100644 --- a/internal/api/external_github_test.go +++ b/internal/api/external_github_test.go @@ -276,7 +276,7 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() { u := performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { @@ -296,5 +296,5 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_kakao_test.go b/internal/api/external_kakao_test.go index 7882e1dce..cd2bd2b29 100644 --- a/internal/api/external_kakao_test.go +++ b/internal/api/external_kakao_test.go @@ -214,7 +214,7 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenVerifiedFalse() { u := performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { @@ -234,5 +234,5 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index 5352299ac..6c0972ea8 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" "net/url" @@ -30,7 +31,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } if state == "" { - return nil, badRequestError("OAuth state parameter missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") } ctx := r.Context() @@ -60,12 +61,12 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError("Authorization code missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r) @@ -107,7 +108,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } oauthToken := getRequestToken(ctx) oauthVerifier := getOAuthVerifier(ctx) @@ -145,6 +146,6 @@ func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthPro case provider.OAuthProvider: return p, nil default: - return nil, badRequestError("Provider can not be used for OAuth") + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) } } diff --git a/internal/api/helpers.go b/internal/api/helpers.go index ea4102f2e..d771dca40 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -109,7 +109,7 @@ func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { return internalServerError("Could not read body into byte slice").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read request body: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not parse request body as JSON: %v", err) } return nil } diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index ec5812e09..15f9ce4d6 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -16,12 +16,12 @@ func TestIsValidCodeChallenge(t *testing.T) { { challenge: "invalid", isValid: false, - expectedError: badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), }, { challenge: "codechallengecontainsinvalidcharacterslike@$^&*", isValid: false, - expectedError: badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), }, { challenge: "validchallengevalidchallengevalidchallengevalidchallenge", @@ -56,12 +56,12 @@ func TestIsValidPKCEParmas(t *testing.T) { { challengeMethod: "test", challenge: "", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, { challengeMethod: "", challenge: "test", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, } diff --git a/internal/api/hooks.go b/internal/api/hooks.go index f3a9e11f1..5368339d8 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -80,8 +80,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -106,8 +106,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -132,8 +132,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -146,8 +146,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: err.Error(), + HTTPStatus: httpCode, + Message: err.Error(), } return httpError diff --git a/internal/api/identity.go b/internal/api/identity.go index f47708555..858810f70 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -2,7 +2,6 @@ package api import ( "context" - "fmt" "net/http" "github.com/fatih/structs" @@ -20,22 +19,22 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") - } - - aud := a.requestAud(ctx, r) - if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return internalServerError("Could not read claims") } identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) if err != nil { - return badRequestError("identity_id must be an UUID") + return notFoundError(ErrorCodeValidationFailed, "identity_id must be an UUID") + } + + aud := a.requestAud(ctx, r) + if aud != claims.Audience { + return forbiddenError(ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") } user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("User must have at least 1 identity after unlinking") + return unprocessableEntityError(ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -46,7 +45,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } } if identityToBeDeleted == nil { - return badRequestError("Identity doesn't exist") + return unprocessableEntityError(ErrorCodeIdentityNotFound, "Identity doesn't exist") } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -73,7 +72,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { default: if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + return unprocessableEntityError(ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) } return internalServerError("Database error updating user email").WithInternalError(terr) } @@ -117,9 +116,9 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora } if identity != nil { if identity.UserID == targetUser.ID { - return nil, badRequestError("Identity is already linked") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked") } - return nil, badRequestError("Identity is already linked to another user") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") } if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { return nil, terr @@ -128,7 +127,7 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora if targetUser.GetEmail() == "" { if terr := targetUser.UpdateUserEmailFromIdentities(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return nil, badRequestError(DuplicateEmailMsg) + return nil, badRequestError(ErrorCodeEmailExists, DuplicateEmailMsg) } return nil, terr } @@ -138,10 +137,10 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora externalURL := getExternalHost(ctx) if terr := sendConfirmation(tx, targetUser, mailer, a.config.SMTP.MaxFrequency, referrer, externalURL, a.config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "For security purposes, you can only request this once every minute") } } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)) } if terr := targetUser.Confirm(tx); terr != nil { return nil, terr diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index 7f70af416..2b193cd21 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -101,7 +101,7 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { }, } u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testExistingUserData, "email") - require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) + require.ErrorIs(ts.T(), err, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked")) require.Nil(ts.T(), u) } @@ -122,13 +122,13 @@ func (ts *IdentityTestSuite) TestUnlinkIdentityError() { desc: "User must have at least 1 identity after unlinking", user: userWithOneIdentity, identityId: userWithOneIdentity.Identities[0].ID, - expectedError: badRequestError("User must have at least 1 identity after unlinking"), + expectedError: unprocessableEntityError(ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking"), }, { desc: "Identity doesn't exist", user: userWithTwoIdentities, identityId: uuid.Must(uuid.NewV4()), - expectedError: badRequestError("Identity doesn't exist"), + expectedError: unprocessableEntityError(ErrorCodeIdentityNotFound, "Identity doesn't exist"), }, } @@ -141,7 +141,7 @@ func (ts *IdentityTestSuite) TestUnlinkIdentityError() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), c.expectedError.Code, w.Code) + require.Equal(ts.T(), c.expectedError.HTTPStatus, w.Code) var data HTTPError require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) diff --git a/internal/api/invite.go b/internal/api/invite.go index 45d94878c..2e912b79c 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -42,7 +42,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { err = db.Transaction(func(tx *storage.Connection) error { if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := SignupParams{ diff --git a/internal/api/invite_test.go b/internal/api/invite_test.go index 466682028..c525e8747 100644 --- a/internal/api/invite_test.go +++ b/internal/api/invite_test.go @@ -162,7 +162,7 @@ func (ts *InviteTestSuite) TestInvite_WithoutAccess() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) // 401 OK because the invite request above has no Authorization header } func (ts *InviteTestSuite) TestVerifyInvite() { diff --git a/internal/api/logout.go b/internal/api/logout.go index ad95b22a4..cd1394eda 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -36,7 +36,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return badRequestError(fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + return badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index ddd3dba37..c0aaded7a 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -24,7 +25,7 @@ type MagicLinkParams struct { func (p *MagicLinkParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return unprocessableEntityError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error p.Email, err = validateEmail(p.Email) @@ -44,14 +45,14 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } params := &MagicLinkParams{} jsonDecoder := json.NewDecoder(r.Body) err := jsonDecoder.Decode(params) if err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) } if err := params.Validate(); err != nil { @@ -82,7 +83,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -94,7 +95,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) r.ContentLength = int64(len(string(newBodyContent))) @@ -113,7 +115,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } metadata, err := json.Marshal(newBodyContent) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(bytes.NewReader(metadata)) return a.MagicLink(w, r) @@ -143,7 +146,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending magic link").WithInternalError(err) } diff --git a/internal/api/mail.go b/internal/api/mail.go index 0ab561ab7..448f5a038 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -66,14 +66,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil { if models.IsNotFoundError(err) { - if params.Type == magicLinkVerification { + switch params.Type { + case magicLinkVerification: params.Type = signupVerification params.Password, err = password.Generate(64, 10, 1, false, true) if err != nil { - return internalServerError("error creating user").WithInternalError(err) + // password generation must always succeed + panic(err) } - } else if params.Type == recoveryVerification || params.Type == "email_change_current" || params.Type == "email_change_new" { - return notFoundError(err.Error()) + + default: + return notFoundError(ErrorCodeUserNotFound, "User with this email not found") } } else { return internalServerError("Database error finding user").WithInternalError(err) @@ -84,7 +87,8 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { now := time.Now() otp, err := crypto.GenerateOtp(config.Mailer.OtpLength) if err != nil { - return err + // OTP generation must always succeed + panic(err) } hashedToken := crypto.GenerateTokenHash(params.Email, otp) @@ -118,11 +122,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.RecoveryToken = hashedToken user.RecoverySentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + } case inviteVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := &SignupParams{ @@ -162,11 +169,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now user.InvitedAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + } case signupVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } if err := user.UpdateUserMetaData(tx, params.Data); err != nil { return internalServerError("Database error updating user").WithInternalError(err) @@ -191,19 +201,22 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + } case "email_change_current", "email_change_new": if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { - return unprocessableEntityError("Enable secure email change to generate link for current email") + return badRequestError(ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") } params.NewEmail, terr = validateEmail(params.NewEmail) if terr != nil { - return unprocessableEntityError("The new email address provided is invalid") + return terr } if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { return internalServerError("Database error checking email").WithInternalError(terr) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } now := time.Now() user.EmailChangeSentAt = &now @@ -214,9 +227,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } else if params.Type == "email_change_new" { user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) } - terr = errors.Wrap(tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status"), "Database error updating user for email change") + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + } default: - return badRequestError("Invalid email action link type requested: %v", params.Type) + return badRequestError(ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) } if terr != nil { @@ -255,7 +271,8 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) @@ -265,7 +282,12 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail return errors.Wrap(err, "Error sending confirmation email") } u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for confirmation") + } + + return nil } func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error { @@ -273,7 +295,8 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -283,7 +306,12 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re } u.InvitedAt = &now u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for invite") + } + + return nil } func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -295,7 +323,8 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -305,7 +334,12 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile return errors.Wrap(err, "Error sending recovery email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, otpLength int) error { @@ -317,7 +351,8 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma oldToken := u.ReauthenticationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -326,7 +361,12 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma return errors.Wrap(err, "Error sending reauthentication email") } u.ReauthenticationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication") + err = tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for reauthentication") + } + + return nil } func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -339,7 +379,8 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -350,7 +391,12 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile return errors.Wrap(err, "Error sending magic link email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } // sendEmailChange sends out an email change token to the new email. @@ -361,7 +407,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } otpNew, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.EmailChange = email token := crypto.GenerateTokenHash(u.EmailChange, otpNew) @@ -371,7 +418,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { otpCurrent, err = crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) @@ -384,22 +432,28 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( + err = tx.UpdateOnly( u, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status", - ), "Database error updating user for email change") + ) + + if err != nil { + return errors.Wrap(err, "Database error updating user for email change") + } + + return nil } func validateEmail(email string) (string, error) { if email == "" { - return "", unprocessableEntityError("An email address is required") + return "", badRequestError(ErrorCodeValidationFailed, "An email address is required") } if err := checkmail.ValidateFormat(email); err != nil { - return "", unprocessableEntityError("Unable to validate email address: " + err.Error()) + return "", badRequestError(ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 6bdb3e596..3919cb781 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -73,11 +73,11 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { return err } - issuer := "" if params.FactorType != models.TOTP { - return badRequestError("factor_type needs to be totp") + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp") } + issuer := "" if params.Issuer == "" { u, err := url.ParseRequestURI(config.SiteURL) if err != nil { @@ -103,15 +103,15 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of verified factors reached, unenroll to continue") + return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError("AAL2 required to enroll a new factor") + return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") } key, err := totp.Generate(totp.GenerateOpts{ @@ -138,7 +138,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if terr := tx.Create(factor); terr != nil { pgErr := utilities.NewPostgresError(terr) if pgErr.IsUniqueConstraintViolated() { - return badRequestError(fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) + return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) } return terr @@ -216,20 +216,20 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { challenge, err := models.FindChallengeByID(a.db, params.ChallengeID) if err != nil && models.IsNotFoundError(err) { - return notFoundError(err.Error()) + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") } else if err != nil { return internalServerError("Database error finding Challenge").WithInternalError(err) } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return badRequestError("Challenge and verify IP addresses mismatch") + return unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { if err := a.db.Destroy(challenge); err != nil { return internalServerError("Database error deleting challenge").WithInternalError(err) } - return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } valid := totp.Validate(params.Code, factor.Secret) @@ -257,11 +257,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output.Message = hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(output.Message) + return forbiddenError(ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { - return badRequestError("Invalid TOTP code entered") + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered") } var token *AccessTokenResponse @@ -322,7 +322,7 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { } if factor.IsVerified() && !session.IsAAL2() { - return badRequestError("AAL2 required to unenroll verified factor") + return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") } if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index bb3c91968..39ec9f2cc 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -168,16 +168,15 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() { issuer := "https://issuer.com" token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) _ = performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusOK) - response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusBadRequest) + response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusUnprocessableEntity) var errorResponse HTTPError err := json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", friendlyName) + expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) - } func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { @@ -226,13 +225,13 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { desc: "Invalid: Valid code and expired challenge", validChallenge: false, validCode: true, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Invalid: Invalid code and valid challenge ", validChallenge: true, validCode: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Valid /verify request", @@ -309,7 +308,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { { desc: "Verified Factor: AAL1", isAAL2: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Verified Factor: AAL2, Success", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index ab72e32c0..6a6d68a25 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -66,7 +66,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } else { err := tollbooth.LimitByKeys(lmt, []string{key}) if err != nil { - return c, httpError(http.StatusTooManyRequests, "Rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverRequestRateLimit, "Request rate limit reached") } } } @@ -101,7 +101,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if err := retrieveRequestParams(req, &requestBody); err != nil { - return c, badRequestError("Error invalid request body").WithInternalError(err) + return c, err } if shouldRateLimitEmail { @@ -112,7 +112,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { 1, attribute.String("path", req.URL.Path), ) - return c, httpError(http.StatusTooManyRequests, "Email rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "Email rate limit exceeded") } } } @@ -120,7 +120,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitPhone { if requestBody.Phone != "" { if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { - return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } } } @@ -151,7 +151,7 @@ func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (co config := a.config if !config.External.Email.Enabled { - return nil, badRequestError("Email logins are disabled") + return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } return ctx, nil @@ -178,8 +178,7 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C } if !verificationResult.Success { - return nil, badRequestError("captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) - + return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) } return ctx, nil @@ -223,7 +222,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { - return nil, notFoundError("SAML 2.0 is disabled") + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") } return ctx, nil } @@ -231,7 +230,7 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { - return nil, notFoundError("Manual linking is disabled") + return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled") } return ctx, nil } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index c532a50ef..e591121f8 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -176,7 +176,7 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { w := httptest.NewRecorder() _, err := ts.API.verifyCaptcha(w, req) - require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).Code) + require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).HTTPStatus) require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) }) } @@ -201,8 +201,8 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { }, }, { - desc: "Sms rate limit exceeded", - expectedErrorMsg: "429: Sms rate limit exceeded", + desc: "SMS rate limit exceeded", + expectedErrorMsg: "429: SMS rate limit exceeded", requestBody: map[string]interface{}{ "phone": "+1233456789", }, @@ -269,7 +269,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { { desc: "SAML not enabled", isEnabled: false, - expectedErr: notFoundError("SAML 2.0 is disabled"), + expectedErr: notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), }, { desc: "SAML enabled", diff --git a/internal/api/otp.go b/internal/api/otp.go index 0e437faa1..99b7bae32 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -34,10 +34,10 @@ type SmsParams struct { func (p *OtpParams) Validate() error { if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided") } if p.Email != "" && p.Channel != "" { - return badRequestError("Channel should only be specified with Phone OTP") + return badRequestError(ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -47,7 +47,7 @@ func (p *OtpParams) Validate() error { func (p *SmsParams) Validate(smsProvider string) error { if p.Phone != "" && !sms_provider.IsValidMessageChannel(p.Channel, smsProvider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } var err error @@ -80,7 +80,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if ok, err := a.shouldCreateUser(r, params); !ok { - return badRequestError("Signups not allowed for otp") + return unprocessableEntityError(ErrorCodeOTPDisabled, "Signups not allowed for otp") } else if err != nil { return err } @@ -91,7 +91,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { return a.SmsOtp(w, r) } - return otpError("unsupported_otp_type", "") + return badRequestError(ErrorCodeValidationFailed, "One of email or phone must be set") } type SmsOtpResponse struct { @@ -105,7 +105,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Phone.Enabled { - return badRequestError("Unsupported phone provider") + return badRequestError(ErrorCodePhoneProviderDisabled, "Unsupported phone provider") } var err error @@ -141,7 +141,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -152,7 +152,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) @@ -170,7 +171,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) return a.SmsOtp(w, r) @@ -191,11 +193,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(err) } mID, serr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { - return badRequestError("Error sending sms OTP: %v", serr) + return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } messageID = mID return nil diff --git a/internal/api/otp_test.go b/internal/api/otp_test.go index be3b18114..c72fbc361 100644 --- a/internal/api/otp_test.go +++ b/internal/api/otp_test.go @@ -80,8 +80,9 @@ func (ts *OtpTestSuite) TestOtpPKCE() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "PKCE flow requires code_challenge_method and code_challenge", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", }, }, }, @@ -98,8 +99,9 @@ func (ts *OtpTestSuite) TestOtpPKCE() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "PKCE flow requires code_challenge_method and code_challenge", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", }, }, }, @@ -115,10 +117,10 @@ func (ts *OtpTestSuite) TestOtpPKCE() { code int response map[string]interface{} }{ - http.StatusBadRequest, + http.StatusInternalServerError, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Error sending sms:", + "code": float64(http.StatusInternalServerError), + "msg": "Unable to get SMS provider", }, }, }, @@ -182,8 +184,9 @@ func (ts *OtpTestSuite) TestOtp() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Only an email address or phone number should be provided", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "Only an email address or phone number should be provided", }, }, }, @@ -200,8 +203,9 @@ func (ts *OtpTestSuite) TestOtp() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": InvalidChannelError, + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": InvalidChannelError, }, }, }, @@ -244,15 +248,16 @@ func (ts *OtpTestSuite) TestNoSignupsForOtp() { ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusBadRequest, w.Code) + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) data := make(map[string]interface{}) require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) // response should be empty assert.Equal(ts.T(), data, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Signups not allowed for otp", + "code": float64(http.StatusUnprocessableEntity), + "error_code": ErrorCodeOTPDisabled, + "msg": "Signups not allowed for otp", }) } diff --git a/internal/api/phone.go b/internal/api/phone.go index cea94a944..f85caa6fd 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -24,7 +24,7 @@ const ( func validatePhone(phone string) (string, error) { phone = formatPhoneNumber(phone) if isValid := validateE164Format(phone); !isValid { - return "", unprocessableEntityError("Invalid phone number format (E.164 required)") + return "", badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } return phone, nil } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 3c543634d..09810e288 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -177,8 +177,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "password": "testpassword", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending confirmation sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -190,8 +190,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "123456789", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -203,8 +203,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "111111111", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -214,8 +214,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { header: "", body: nil, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, } @@ -244,7 +244,12 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { require.Equal(ts.T(), c.expected["code"], w.Code) body := w.Body.String() - require.True(ts.T(), strings.Contains(body, c.expected["message"].(string))) + require.True(ts.T(), + strings.Contains(body, "Unable to get SMS provider") || + strings.Contains(body, "Error finding SMS provider") || + strings.Contains(body, "Failed to get SMS provider"), + "unexpected body message %q", body, + ) }) } } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index 56ac1acb9..5ac75668d 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -21,9 +21,9 @@ func isValidCodeChallenge(codeChallenge string) (bool, error) { // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 switch codeChallengeLength := len(codeChallenge); { case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: - return false, badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + return false, badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) case !codeChallengePattern.MatchString(codeChallenge): - return false, badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + return false, badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") default: return true, nil } @@ -41,7 +41,7 @@ func addFlowPrefixToToken(token string, flowType models.FlowType) string { func issueAuthCode(tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod) (string, error) { flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) if err != nil && models.IsNotFoundError(err) { - return "", badRequestError("No valid flow state found for user.") + return "", unprocessableEntityError(ErrorCodeFlowStateNotFound, "No valid flow state found for user.") } else if err != nil { return "", err } @@ -63,7 +63,7 @@ func isImplicitFlow(flowType models.FlowType) bool { func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { switch true { case (codeChallenge == "") != (codeChallengeMethod == ""): - return badRequestError(InvalidPKCEParamsErrorMessage) + return badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) case codeChallenge != "": if valid, err := isValidCodeChallenge(codeChallenge); !valid { return err diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index b62a51fc0..84b080070 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -23,16 +23,16 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { - return unprocessableEntityError("Reauthentication requires the user to have an email or a phone number") + return badRequestError(ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") } if email != "" { if !user.IsConfirmed() { - return badRequestError("Please verify your email first.") + return unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Please verify your email first.") } } else if phone != "" { if !user.IsPhoneConfirmed() { - return badRequestError("Please verify your phone first.") + return unprocessableEntityError(ErrorCodePhoneNotConfirmed, "Please verify your phone first.") } } @@ -47,7 +47,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { } else if phone != "" { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Failed to get SMS provider").WithInternalError(terr) } mID, err := a.sendPhoneConfirmation(tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { @@ -60,7 +60,12 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + reason := ErrorCodeOverEmailSendRateLimit + if phone != "" { + reason = ErrorCodeOverSMSSendRateLimit + } + + return tooManyRequestsError(reason, "For security purposes, you can only request this once every 60 seconds") } return err } @@ -77,7 +82,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { // verifyReauthentication checks if the nonce provided is valid func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } var isValid bool if user.GetEmail() != "" { @@ -87,7 +92,7 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi if config.Sms.IsTwilioVerifyProvider() { smsProvider, _ := sms_provider.GetSmsProvider(*config) if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { - return expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return nil } else { @@ -95,10 +100,10 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) } } else { - return unprocessableEntityError("Reauthentication requires an email or a phone number") + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") } if !isValid { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } if err := user.ConfirmReauthentication(tx); err != nil { return internalServerError("Error during reauthentication").WithInternalError(err) diff --git a/internal/api/recover.go b/internal/api/recover.go index 77e3c068d..a3201852d 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -18,7 +18,7 @@ type RecoverParams struct { func (p *RecoverParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return badRequestError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error if p.Email, err = validateEmail(p.Email); err != nil { @@ -73,7 +73,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/resend.go b/internal/api/resend.go index cb8c4da24..fdad38c43 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -25,22 +25,22 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er break default: // type does not match one of the above - return badRequestError("Missing one of these types: signup, email_change, sms, phone_change") + return badRequestError(ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") } if p.Email == "" && p.Type == signupVerification { - return badRequestError("Type provided requires an email address") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires an email address") } if p.Phone == "" && p.Type == smsVerification { - return badRequestError("Type provided requires a phone number") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires a phone number") } var err error if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") } else if p.Email != "" { if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } p.Email, err = validateEmail(p.Email) if err != nil { @@ -48,7 +48,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else if p.Phone != "" { if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } p.Phone, err = validatePhone(p.Phone) if err != nil { @@ -56,7 +56,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else { // both email and phone are empty - return badRequestError("Missing email address or phone number") + return badRequestError(ErrorCodeValidationFailed, "Missing email address or phone number") } return nil } @@ -156,8 +156,13 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { + reason := ErrorCodeOverEmailSendRateLimit + if params.Type == smsVerification || params.Type == phoneChangeVerification { + reason = ErrorCodeOverSMSSendRateLimit + } + until := time.Until(user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency)) / time.Second - return tooManyRequestsError("For security purposes, you can only request this once every %d seconds.", until) + return tooManyRequestsError(reason, "For security purposes, you can only request this once every %d seconds.", until) } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/router.go b/internal/api/router.go index c2f06ae2e..70b41f22d 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -63,7 +63,7 @@ func handler(fn apiHandler) http.HandlerFunc { func (h apiHandler) serve(w http.ResponseWriter, r *http.Request) { if err := h(w, r); err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) } } @@ -78,7 +78,7 @@ func (m middlewareHandler) handler(next http.Handler) http.Handler { func (m middlewareHandler) serve(next http.Handler, w http.ResponseWriter, r *http.Request) { ctx, err := m(w, r) if err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) return } if ctx != nil { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index a3932249a..d82117748 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -67,7 +67,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return badRequestError("SAML RelayState does not exist, try logging in again?") + return notFoundError(ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { return err } @@ -77,7 +77,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return badRequestError("SAML RelayState has expired. Try loggin in again?") + return unprocessableEntityError(ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -107,23 +107,23 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return badRequestError("SAML Artifact response can only be used with SP initiated flow") + return badRequestError(ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return badRequestError("SAMLResponse is missing") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid Base64 string") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid XML SAML assertion") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -131,12 +131,12 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return badRequestError("SAML RelayState is not a valid UUID or URL") + return badRequestError(ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return badRequestError("A SAML connection has not been established with this Identity Provider") + return notFoundError(ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { return err } @@ -176,10 +176,10 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return badRequestError("SAML Assertion is not valid").WithInternalError(ire.PrivateErr) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(ire.PrivateErr) } - return badRequestError("SAML Assertion is not valid").WithInternalError(err) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -188,7 +188,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return badRequestError("SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return badRequestError(ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -200,7 +200,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return badRequestError("SAML Assertion does not contain an email address") + return badRequestError(ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } diff --git a/internal/api/signup.go b/internal/api/signup.go index 3d7a19000..5c7e588b8 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -34,21 +34,21 @@ func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { config := a.config if p.Password == "" { - return unprocessableEntityError("Signup requires a valid password") + return badRequestError(ErrorCodeValidationFailed, "Signup requires a valid password") } if err := a.checkPasswordStrength(ctx, p.Password); err != nil { return err } if p.Email != "" && p.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on signup.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") } if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } // PKCE not needed as phone signups already return access token in body if p.Phone != "" && p.CodeChallenge != "" { - return badRequestError("PKCE not supported for phone signups") + return badRequestError(ErrorCodeValidationFailed, "PKCE not supported for phone signups") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -114,7 +114,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} @@ -141,7 +141,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { switch params.Provider { case "email": if !config.External.Email.Enabled { - return badRequestError("Email signups are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email signups are disabled") } params.Email, err = validateEmail(params.Email) if err != nil { @@ -150,7 +150,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) case "phone": if !config.External.Phone.Enabled { - return badRequestError("Phone signups are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone signups are disabled") } params.Phone, err = validatePhone(params.Phone) if err != nil { @@ -158,7 +158,18 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } user, err = models.FindUserByPhoneAndAudience(db, params.Phone, params.Aud) default: - return invalidSignupError(config) + msg := "" + if config.External.Email.Enabled && config.External.Phone.Enabled { + msg = "Sign up only available with email or phone provider" + } else if config.External.Email.Enabled { + msg = "Sign up only available with email provider" + } else if config.External.Phone.Enabled { + msg = "Sign up only available with phone provider" + } else { + msg = "Sign up with this provider not possible" + } + + return badRequestError(ErrorCodeValidationFailed, msg) } if err != nil && !models.IsNotFoundError(err) { @@ -241,7 +252,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if errors.Is(terr, MaxFrequencyLimitError) { now := time.Now() left := user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency).Sub(now) / time.Second - return tooManyRequestsError(fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) } return internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -265,10 +276,10 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } } @@ -277,10 +288,14 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every minute") + reason := ErrorCodeOverEmailSendRateLimit + if params.Provider == "phone" { + reason = ErrorCodeOverSMSSendRateLimit } - if errors.Is(err, UserExistsError) { + + if errors.Is(err, MaxFrequencyLimitError) { + return tooManyRequestsError(reason, "For security purposes, you can only request this once every minute") + } else if errors.Is(err, UserExistsError) { err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ "provider": params.Provider, @@ -293,7 +308,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return err } if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { - return badRequestError("User already registered") + return unprocessableEntityError(ErrorCodeUserAlreadyExists, "User already registered") } sanitizedUser, err := sanitizeUser(user, params) if err != nil { diff --git a/internal/api/sso.go b/internal/api/sso.go index 08ca4c616..e50d1a369 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -27,9 +27,9 @@ func (p *SingleSignOnParams) validate() (bool, error) { hasDomain := p.Domain != "" if hasProviderID && hasDomain { - return hasProviderID, badRequestError("Only one of provider_id or domain supported") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "Only one of provider_id or domain supported") } else if !hasProviderID && !hasDomain { - return hasProviderID, badRequestError("A provider_id or domain needs to be provided") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") } return hasProviderID, nil @@ -73,14 +73,14 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if hasProviderID { ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) if models.IsNotFoundError(err) { - return notFoundError("No such SSO provider") + return notFoundError(ErrorCodeSSOProviderNotFound, "No such SSO provider") } else if err != nil { return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) } } else { ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) if models.IsNotFoundError(err) { - return notFoundError("No SSO provider assigned for this domain") + return notFoundError(ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") } else if err != nil { return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) } diff --git a/internal/api/sso_test.go b/internal/api/sso_test.go index 7da8b6eb2..5fc46b2d0 100644 --- a/internal/api/sso_test.go +++ b/internal/api/sso_test.go @@ -277,7 +277,7 @@ func (ts *SSOTestSuite) TestAdminCreateSSOProvider() { }, }, { - StatusCode: http.StatusBadRequest, + StatusCode: http.StatusUnprocessableEntity, Request: map[string]interface{}{ "type": "saml", "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 0f966780e..6e52fc8ff 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -28,14 +28,14 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C idpID, err := uuid.FromString(idpParam) if err != nil { // idpParam is not UUIDv4 - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } // idpParam is a UUIDv4 provider, err := models.FindSSOProviderByID(db, idpID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } else { return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) } @@ -78,19 +78,19 @@ type CreateSSOProviderParams struct { func (p *CreateSSOProviderParams) validate(forUpdate bool) error { if !forUpdate && p.Type != "saml" { - return badRequestError("Only 'saml' supported for SSO provider type") + return badRequestError(ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") } else if p.MetadataURL != "" && p.MetadataXML != "" { - return badRequestError("Only one of metadata_xml or metadata_url needs to be set") + return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { - return badRequestError("Either metadata_xml or metadata_url must be set") + return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") } else if p.MetadataURL != "" { metadataURL, err := url.ParseRequestURI(p.MetadataURL) if err != nil { - return badRequestError("metadata_url is not a valid URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL") } if metadataURL.Scheme != "https" { - return badRequestError("metadata_url is not a HTTPS URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") } } @@ -126,7 +126,7 @@ func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.E func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { if !utf8.Valid(rawMetadata) { - return nil, badRequestError("SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") } metadata, err := samlsp.ParseMetadata(rawMetadata) @@ -135,15 +135,15 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { } if metadata.EntityID == "" { - return nil, badRequestError("SAML Metadata does not contain an EntityID") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") } if len(metadata.IDPSSODescriptors) < 1 { - return nil, badRequestError("SAML Metadata does not contain any IDPSSODescriptor") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") } if len(metadata.IDPSSODescriptors) > 1 { - return nil, badRequestError("SAML Metadata contains multiple IDPSSODescriptors") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") } return metadata, nil @@ -152,7 +152,7 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, badRequestError("Unable to create a request to metadata_url").WithInternalError(err) + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) } req = req.WithContext(ctx) @@ -167,7 +167,7 @@ func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { defer utilities.SafeClose(resp.Body) if resp.StatusCode != http.StatusOK { - return nil, badRequestError("HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) } data, err := io.ReadAll(resp.Body) @@ -202,7 +202,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) } provider := &models.SSOProvider{ @@ -225,7 +225,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) } provider.SSODomains = append(provider.SSODomains, models.SSODomain{ @@ -280,7 +280,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er } if provider.SAMLProvider.EntityID != metadata.EntityID { - return badRequestError("SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) } if params.MetadataURL != "" { @@ -309,7 +309,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er if existingProvider.ID == provider.ID { keepDomains[domain] = true } else { - return badRequestError("SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) } } else { modified = true @@ -359,7 +359,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er return tx.Eager().Load(provider) }); err != nil { - return unprocessableEntityError("Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) + return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) } } diff --git a/internal/api/token.go b/internal/api/token.go index 2f6f9e3b2..df0292711 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -108,7 +108,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri config := a.config if params.Email != "" && params.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on login.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") } var user *models.User var grantParams models.GrantParams @@ -120,13 +120,13 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if params.Email != "" { provider = "email" if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) } else if params.Phone != "" { provider = "phone" if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return unprocessableEntityError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } params.Phone = formatPhoneNumber(params.Phone) user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) @@ -178,7 +178,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return err } } - return forbiddenError(output.Message) + return oauthError("invalid_grant", InvalidLoginMessage) } } if !isValidPassword { @@ -230,24 +230,23 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) grantParams.FillGrantParams(r) params := &PKCEGrantParams{} - if err := retrieveRequestParams(r, params); err != nil { return err } if params.AuthCode == "" || params.CodeVerifier == "" { - return badRequestError("invalid request: both auth code and code verifier should be non-empty") + return badRequestError(ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") } flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) // Sanity check in case user ID was not set properly if models.IsNotFoundError(err) || flowState.UserID == nil { - return forbiddenError("invalid flow state, no valid flow state found") + return notFoundError(ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") } else if err != nil { return err } if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { - return forbiddenError("invalid flow state, flow state has expired") + return unprocessableEntityError(ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") } user, err := models.FindUserByID(db, *flowState.UserID) @@ -255,7 +254,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return err } if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { - return forbiddenError(err.Error()) + return badRequestError(ErrorBadCodeVerifier, err.Error()) } var token *AccessTokenResponse @@ -427,6 +426,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, if err != nil { return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) } + err = tx.Transaction(func(tx *storage.Connection) error { if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil { return terr @@ -459,7 +459,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return err } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn) + tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &session.ID, models.TOTPSignIn) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 58e022afd..0574c3bb8 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -54,7 +54,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa if issuer == "" || !provider.IsAzureIssuer(issuer) { detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) if err != nil { - return nil, nil, "", nil, badRequestError("Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) } issuer = detectedIssuer } @@ -95,12 +95,12 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) } } if cfg != nil && !cfg.Enabled { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + return nil, nil, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcProvider, err := oidc.NewProvider(ctx, issuer) diff --git a/internal/api/token_test.go b/internal/api/token_test.go index b12a79a8e..53a492b11 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -343,7 +343,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusForbidden, w.Code) + assert.Equal(ts.T(), http.StatusNotFound, w.Code) }) } } @@ -618,7 +618,7 @@ func (ts *TokenTestSuite) TestPasswordVerificationHook() { begin return jsonb_build_object('decision', 'reject', 'message', 'You shall not pass!'); end; $$ language plpgsql;`, - expectedCode: http.StatusForbidden, + expectedCode: http.StatusBadRequest, }, } for _, c := range cases { diff --git a/internal/api/user.go b/internal/api/user.go index 723521c17..9fe0dcef8 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -45,7 +45,7 @@ func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) p.Channel = sms_provider.SMSProvider } if !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } } @@ -63,12 +63,12 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") + return internalServerError("Could not read claims") } aud := a.requestAud(ctx, r) if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return badRequestError(ErrorCodeValidationFailed, "Token audience doesn't match request audience") } user := getUser(ctx) @@ -96,7 +96,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if params.AppData != nil && !isAdmin(user, config) { if !isAdmin(user, config) { - return unauthorizedError("Updating app_metadata requires admin privileges") + return forbiddenError(ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") } } @@ -104,7 +104,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields := false updatingForbiddenFields = updatingForbiddenFields || (params.Password != nil && *params.Password != "") if updatingForbiddenFields { - return unprocessableEntityError("Updating password of an anonymous user is not possible") + // CHECK + return unprocessableEntityError(ErrorCodeUnknown, "Updating password of an anonymous user is not possible") } } @@ -117,7 +118,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") if updatingForbiddenFields { - return unprocessableEntityError("Updating email, phone, password of a SSO account only possible via SSO") + return unprocessableEntityError(ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") } } @@ -125,7 +126,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } @@ -133,7 +134,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError(DuplicatePhoneMsg) + return unprocessableEntityError(ErrorCodePhoneExists, DuplicatePhoneMsg) } } @@ -143,7 +144,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { // we require reauthentication if the user hasn't signed in recently in the current session if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { if len(params.Nonce) == 0 { - return badRequestError("Password update requires reauthentication") + return badRequestError(ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") } if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { return err @@ -154,7 +155,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { password := *params.Password if password != "" { if user.EncryptedPassword != "" && user.Authenticate(ctx, password) { - return unprocessableEntityError("New password should be different from the old password.") + return unprocessableEntityError(ErrorCodeSamePassword, "New password should be different from the old password.") } } @@ -206,7 +207,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { externalURL := getExternalHost(ctx) if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending change email").WithInternalError(terr) } @@ -224,7 +225,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } else { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Error finding SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) diff --git a/internal/api/user_test.go b/internal/api/user_test.go index 2136f29c7..18ed6ec85 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -281,7 +281,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() { nonce: "123456", requireReauthentication: true, sessionId: nil, - expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, }, { desc: "Valid password length", diff --git a/internal/api/verify.go b/internal/api/verify.go index 4af21e023..d6d5dc541 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -53,18 +53,18 @@ type VerifyParams struct { func (p *VerifyParams) Validate(r *http.Request) error { var err error if p.Type == "" { - return badRequestError("Verify requires a verification type") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type") } switch r.Method { case http.MethodGet: if p.Token == "" { - return badRequestError("Verify requires a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a token or a token hash") } // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) p.TokenHash = p.Token case http.MethodPost: if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { - return badRequestError("Verify requires either a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash") } if p.Token != "" { if isPhoneOtpVerification(p) { @@ -76,15 +76,15 @@ func (p *VerifyParams) Validate(r *http.Request) error { } else if isEmailOtpVerification(p) { p.Email, err = validateEmail(p.Email) if err != nil { - return unprocessableEntityError("Invalid email format").WithInternalError(err) + return unprocessableEntityError(ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) } p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) } else { - return badRequestError("Only an email address or phone number should be provided on verify") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") } } else if p.TokenHash != "" { if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { - return badRequestError("Only the token_hash and type should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only the token_hash and type should be provided") } } default: @@ -114,7 +114,8 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { } return a.verifyPost(w, r, params) default: - return unprocessableEntityError("Only GET and POST methods are supported.") + // this should have been handled by Chi + panic("Only GET and POST methods allowed") } } @@ -165,7 +166,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return nil } default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -193,7 +194,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } } else if isPKCEFlow(flowType) { if authCode, terr = issueAuthCode(tx, user, authenticationMethod); terr != nil { - return badRequestError("No associated flow state found. %s", terr) + return badRequestError(ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) } } return nil @@ -266,7 +267,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP case smsVerification, phoneChangeVerification: user, terr = a.smsVerify(r, tx, user, params) default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -310,7 +311,8 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C // to present the user with a password set form password, err := password.Generate(64, 10, 0, false, true) if err != nil { - return nil, err + // password generation must succeed + panic(err) } if err := user.SetPassword(ctx, password); err != nil { @@ -433,14 +435,14 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string, errorID := getRequestID(r.Context()) err.ErrorID = errorID log.WithError(err.Cause()).Info(err.Error()) - if str, ok := oauthErrorMap[err.Code]; ok { + if str, ok := oauthErrorMap[err.HTTPStatus]; ok { hq.Set("error", str) q.Set("error", str) } - hq.Set("error_code", strconv.Itoa(err.Code)) + hq.Set("error_code", strconv.Itoa(err.HTTPStatus)) hq.Set("error_description", err.Message) - q.Set("error_code", strconv.Itoa(err.Code)) + q.Set("error_code", strconv.Itoa(err.HTTPStatus)) q.Set("error_description", err.Message) if flowType == models.PKCEFlow { // Additionally, may override existing error query param if set to PKCE. @@ -563,18 +565,18 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (* case emailChangeVerification: user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) default: - return nil, badRequestError("Invalid email verification type") + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid email verification type") } if err != nil { if models.IsNotFoundError(err) { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) } return nil, internalServerError("Database error finding user from email link").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } var isExpired bool @@ -596,7 +598,7 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (* } if isExpired { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalMessage("email link has expired") + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") } return user, nil @@ -625,13 +627,13 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, if err != nil { if models.IsNotFoundError(err) { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return nil, internalServerError("Database error finding user").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } var isValid bool @@ -672,7 +674,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, } } if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return user, nil } @@ -680,7 +682,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, } if !isValid { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") } return user, nil } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index c782446fb..ac6e79b71 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -280,9 +280,9 @@ func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { f, err := url.ParseQuery(rurl.Fragment) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "401", f.Get("error_code")) + assert.Equal(ts.T(), "403", f.Get("error_code")) assert.Equal(ts.T(), "Email link is invalid or has expired", f.Get("error_description")) - assert.Equal(ts.T(), "unauthorized_client", f.Get("error")) + assert.Equal(ts.T(), "access_denied", f.Get("error")) } func (ts *VerifyTestSuite) TestInvalidOtp() { @@ -302,7 +302,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { } expectedResponse := ResponseBody{ - Code: http.StatusUnauthorized, + Code: http.StatusForbidden, Msg: "Token has expired or is invalid", } @@ -313,7 +313,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected ResponseBody }{ { - desc: "Expired Sms OTP", + desc: "Expired SMS OTP", sentTime: time.Now().Add(-48 * time.Hour), body: map[string]interface{}{ "type": smsVerification, @@ -323,7 +323,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected: expectedResponse, }, { - desc: "Invalid Sms OTP", + desc: "Invalid SMS OTP", sentTime: time.Now(), body: map[string]interface{}{ "type": smsVerification, @@ -760,7 +760,7 @@ func (ts *VerifyTestSuite) TestVerifyBannedUser() { f, err := url.ParseQuery(rurl.Fragment) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "401", f.Get("error_code")) + assert.Equal(ts.T(), "403", f.Get("error_code")) }) } } @@ -973,7 +973,7 @@ func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { "type": emailChangeVerification, "token_hash": currentEmailChangeToken, }, - expectedStatus: http.StatusUnauthorized, + expectedStatus: http.StatusForbidden, }, } for _, c := range cases { @@ -1102,7 +1102,7 @@ func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { for _, c := range cases { ts.Run(c.desc, func() { req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rurl, err := ts.API.prepErrorRedirectURL(badRequestError(DefaultError), req, c.rurl, c.flowType) + rurl, err := ts.API.prepErrorRedirectURL(badRequestError(ErrorCodeValidationFailed, DefaultError), req, c.rurl, c.flowType) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected, rurl) }) @@ -1152,7 +1152,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Token: "some-token", }, method: http.MethodPost, - expected: badRequestError("Only an email address or phone number should be provided on verify"), + expected: badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), }, { desc: "Cannot send both TokenHash and Token", @@ -1162,7 +1162,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { TokenHash: "some-token-hash", }, method: http.MethodPost, - expected: badRequestError("Verify requires either a token or a token hash"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), }, { desc: "No verification type specified", @@ -1171,7 +1171,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Email: "email@example.com", }, method: http.MethodPost, - expected: badRequestError("Verify requires a verification type"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type"), }, }