From e39473f702e0fdbb4ace2582fc6564f86bca8bed Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Fri, 1 Nov 2024 12:06:24 +0100 Subject: [PATCH] fix: don't return on logout, make it idempotent --- internal/api/auth.go | 10 +++++++--- internal/api/auth_test.go | 2 +- internal/api/mfa_test.go | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/internal/api/auth.go b/internal/api/auth.go index b03767f02..6062238c2 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/go-chi/chi/v5" "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/conf" @@ -25,7 +26,10 @@ func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (con return ctx, err } - ctx, err = a.maybeLoadUserOrSession(ctx) + routeContext := chi.RouteContext(ctx) + skipSessionMissingError := routeContext != nil && routeContext.RouteMethod == http.MethodPost && routeContext.RoutePath == "/logout" + + ctx, err = a.maybeLoadUserOrSession(ctx, skipSessionMissingError) if err != nil { return ctx, err } @@ -94,7 +98,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return withToken(ctx, token), nil } -func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, error) { +func (a *API) maybeLoadUserOrSession(ctx context.Context, skipSessionMissingError bool) (context.Context, error) { db := a.db.WithContext(ctx) claims := getClaims(ctx) @@ -130,7 +134,7 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro } session, err = models.FindSessionByID(db, sessionId, false) if err != nil { - if models.IsNotFoundError(err) { + if models.IsNotFoundError(err) && !skipSessionMissingError { return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) } return ctx, err diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 71afe6638..0e7ebe8cf 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -271,7 +271,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { ctx, err := ts.API.parseJWTClaims(userJwt, req) require.NoError(ts.T(), err) - ctx, err = ts.API.maybeLoadUserOrSession(ctx) + ctx, err = ts.API.maybeLoadUserOrSession(ctx, false) if c.ExpectedError != nil { require.Equal(ts.T(), c.ExpectedError.Error(), err.Error()) } else { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 4e2a79758..542e50f53 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -676,7 +676,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { ctx, err := ts.API.parseJWTClaims(data.Token, req) require.NoError(ts.T(), err) - ctx, err = ts.API.maybeLoadUserOrSession(ctx) + ctx, err = ts.API.maybeLoadUserOrSession(ctx, false) require.NoError(ts.T(), err) require.True(ts.T(), getSession(ctx).IsAAL2()) }