From e278b405e53f6e00f3012a49f14938443d6e7882 Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Mon, 21 Oct 2024 11:30:02 +0200 Subject: [PATCH] feat: graceful refresh token rotation (#3860) This patch adds a configuration flag which enables graceful refresh token rotation. Previously, refresh tokens could only be used once. On reuse, all tokens of that chain would be revoked. This is particularly challenging in environments, where it's difficult to make guarantees on synchronization. This could lead to refresh tokens being sent twice due to some parallel execution. To resolve this, refresh tokens can now be graceful by changing `oauth2.grant.refresh_token.grace_period=10s` (example value). During this time, a refresh token can be used multiple times to generate new refresh, ID, and access tokens. All tokens will correctly be invalidated, when the refresh token is re-used after the grace period expires, or when the delete consent endpoint is used. Closes #1831 #3770 --- driver/config/provider.go | 9 + driver/config/provider_test.go | 7 + internal/config/config.yaml | 12 ++ oauth2/fosite_store_helpers.go | 143 ++++++++++---- oauth2/oauth2_auth_code_test.go | 185 +++++++++++++++++- persistence/sql/migratest/migration_test.go | 20 +- ...efresh_token_in_grace_period_flag.down.sql | 1 + ..._refresh_token_in_grace_period_flag.up.sql | 1 + persistence/sql/persister.go | 1 + persistence/sql/persister_oauth2.go | 67 ++++++- spec/config.json | 19 +- x/fosite_storer.go | 5 +- 12 files changed, 409 insertions(+), 61 deletions(-) create mode 100644 persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql create mode 100644 persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql diff --git a/driver/config/provider.go b/driver/config/provider.go index a6abecaad0e..4f66e4448aa 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -102,6 +102,7 @@ const ( KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim" KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims" KeyMirrorTopLevelClaims = "oauth2.mirror_top_level_claims" + KeyRefreshTokenRotationGracePeriod = "oauth2.grant.refresh_token.rotation_grace_period" // #nosec G101 KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" @@ -669,3 +670,11 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { return p.getProvider(ctx).String(key) + suffix } + +func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { + gracePeriod := p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) + if gracePeriod > time.Hour { + return time.Hour + } + return gracePeriod +} diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index 8e5c44a9e2e..168ca81d69f 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -291,6 +291,13 @@ func TestViperProviderValidates(t *testing.T) { assert.Equal(t, "random_salt", c.SubjectIdentifierAlgorithmSalt(ctx)) assert.Equal(t, []string{"whatever"}, c.DefaultClientScope(ctx)) + // refresh + assert.Equal(t, time.Duration(0), c.RefreshTokenRotationGracePeriod(ctx)) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s")) + assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod(ctx)) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h")) + assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod(ctx)) + // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx)) assert.Equal(t, urlx.ParseOrPanic("https://public/"), c.PublicURL(ctx)) diff --git a/internal/config/config.yaml b/internal/config/config.yaml index f3e8bff399c..49615d95966 100644 --- a/internal/config/config.yaml +++ b/internal/config/config.yaml @@ -402,6 +402,18 @@ oauth2: session: # store encrypted data in database, default true encrypt_at_rest: true + ## refresh_token_rotation + # By default Refresh Tokens are rotated and invalidated with each use. See https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#section-4.13.2 for more details + refresh_token_rotation: + # + ## grace_period + # + # Set the grace period for refresh tokens to be reused. Such reused tokens will result in multiple refresh tokens being issued. + # + # Examples: + # - 5s + # - 1m + grace_period: 0s # The secrets section configures secrets used for encryption and signing of several systems. All secrets can be rotated, # for more information on this topic navigate to: diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index 6b64d61991e..553a6bae62b 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -25,6 +25,7 @@ import ( "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/x" "github.com/ory/fosite/storage" @@ -225,16 +226,18 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) + t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) } func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { return func(t *testing.T) { - requestId := uuid.New() - mockRequestForeignKey(t, requestId, m) + ctx := context.Background() + requestID := uuid.New() + mockRequestForeignKey(t, requestID, m) cl := &client.Client{ID: "foobar"} fositeRequest := &fosite.Request{ - ID: requestId, + ID: requestID, Client: cl, RequestedAt: time.Now().UTC().Round(time.Second), Session: NewSession("bar"), @@ -242,15 +245,15 @@ func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing. for i := 0; i < 4; i++ { signature := uuid.New() - err := m.OAuth2Storage().CreateRefreshTokenSession(context.TODO(), signature, fositeRequest) + err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreateAccessTokenSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreateAccessTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreatePKCERequestSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreatePKCERequestSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreateAuthorizeCodeSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreateAuthorizeCodeSession(ctx, signature, fositeRequest) assert.NoError(t, err) } } @@ -475,7 +478,7 @@ func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) { m := x.OAuth2Storage() c := &client.Client{ID: "nil-request-client-id-123"} require.NoError(t, x.ClientManager().CreateClient(context.Background(), c)) - err := m.CreateAccessTokenSession(context.TODO(), "nil-request-id", &fosite.Request{ + err := m.CreateAccessTokenSession(context.Background(), "nil-request-id", &fosite.Request{ ID: "", RequestedAt: time.Now().UTC().Round(time.Second), Client: c, @@ -553,6 +556,63 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { } } +func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { + return func(t *testing.T) { + ctx := context.Background() + + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { + // SETUP + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + require.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + require.NoError(t, err) + + tmpSession := new(fosite.Session) + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + + // ASSERT + // a revoked refresh token returns an error when getting the token again + assert.ErrorIs(t, err, fosite.ErrInactiveToken) + }) + + t.Run("refresh token enters grace period when configured,", func(t *testing.T) { + // SETUP + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1m") + + // always reset back to the default + t.Cleanup(func() { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "0m") + }) + + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + require.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + + req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + + // ASSERT + // when grace period is configured the refresh token can be obtained within + // the grace period without error + assert.NoError(t, err) + + assert.Equal(t, defaultRequest.GetID(), req.GetID()) + }) + } + +} + func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() @@ -880,6 +940,7 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { + ctx := context.Background() grantManager := x.GrantManager() keyManager := x.KeyManager() grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage) @@ -902,28 +963,28 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - storedKeySet, err := grantStorage.GetPublicKeys(context.TODO(), issuer, subject) + storedKeySet, err := grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) require.Len(t, storedKeySet.Keys, 0) - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - storedKeySet, err = grantStorage.GetPublicKeys(context.TODO(), issuer, subject) + storedKeySet, err = grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) assert.Len(t, storedKeySet.Keys, 1) - storedKey, err := grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + storedKey, err := grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, publicKey.KeyID, storedKey.KeyID) assert.Equal(t, publicKey.Use, storedKey.Use) assert.Equal(t, publicKey.Key, storedKey.Key) - storedScopes, err := grantStorage.GetPublicKeyScopes(context.TODO(), issuer, subject, publicKey.KeyID) + storedScopes, err := grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, grant.Scope, storedScopes) - storedKeySet, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + storedKeySet, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, publicKey.KeyID, storedKeySet.Keys[0].KeyID) assert.Equal(t, publicKey.Use, storedKeySet.Keys[0].Use) @@ -953,7 +1014,7 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-2", "sig") require.NoError(t, err) - require.NoError(t, grantManager.CreateGrant(context.TODO(), trust.Grant{ + require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{ ID: uuid.New(), Issuer: issuer, Subject: subject, @@ -1011,22 +1072,22 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, grant.PublicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, grant.PublicKey.KeyID) require.NoError(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) - err = grantManager.DeleteGrant(context.TODO(), grant.ID) + err = grantManager.DeleteGrant(ctx, grant.ID) require.NoError(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) assert.Error(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) assert.Error(t, err) }) @@ -1048,22 +1109,22 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) - err = keyManager.DeleteKey(context.TODO(), issuer, publicKey.KeyID) + err = keyManager.DeleteKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) assert.Error(t, err) - _, err = grantManager.GetConcreteGrant(context.TODO(), grant.ID) + _, err = grantManager.GetConcreteGrant(ctx, grant.ID) assert.Error(t, err) }) @@ -1085,25 +1146,25 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) // All three get methods should only return the public key when using the valid subject - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, "any-subject-1", publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID) require.Error(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) - _, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "any-subject-2", publicKey.KeyID) + _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID) require.Error(t, err) - _, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) - jwks, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3") + jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) require.NotNil(t, jwks) require.Empty(t, jwks.Keys) - jwks, err = grantStorage.GetPublicKeys(context.TODO(), issuer, subject) + jwks, err = grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) require.NotNil(t, jwks) require.NotEmpty(t, jwks.Keys) @@ -1126,17 +1187,17 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) // All three get methods should always return the public key - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, "any-subject-1", publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID) require.NoError(t, err) - _, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "any-subject-2", publicKey.KeyID) + _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID) require.NoError(t, err) - jwks, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3") + jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) require.NotNil(t, jwks) require.NotEmpty(t, jwks.Keys) @@ -1159,10 +1220,10 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(-1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - keys, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3") + keys, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) assert.Len(t, keys.Keys, 0) }) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index cb12f68c1ab..0d89e14ac9b 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -176,8 +176,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { } assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { - actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64) - require.NoError(t, err) + introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) + actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) + require.NoError(t, err, "%s", introspect) requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) } @@ -332,6 +333,186 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) + t.Run("case=graceful token rotation", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + t.Cleanup(func() { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + }) + + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, nil), + acceptConsentHandler(t, c, subject, nil), + ) + + issueTokens := func(t *testing.T) *oauth2.Token { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) + + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return token + } + + refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat := time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + + introspectAccessToken(t, conf, refreshedToken, subject) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return refreshedToken + } + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + first = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + second = refreshTokens(t, token) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + first = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + second = refreshTokens(t, token) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + _, err = conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { + start := time.Now() + token := issueTokens(t) + var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + a1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + b1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=first refresh from first refresh", func(t *testing.T) { + a2RefreshA = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=second refresh from first refresh", func(t *testing.T) { + a2RefreshB = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=first refresh from second refresh", func(t *testing.T) { + b2RefreshA = refreshTokens(t, b1Refresh) + }) + + t.Run("followup=second refresh from second refresh", func(t *testing.T) { + b2RefreshB = refreshTokens(t, b1Refresh) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, token := range []*oauth2.Token{ + a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + } + }) + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) + t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { // Make sure we test against all crypto suites that we advertise. cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 1fa0ce3836d..8564cfab969 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -144,7 +144,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) { ss := []flow.LoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 17, len(ss)) for _, s := range ss { @@ -157,7 +157,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_obfuscated_authentication_session", func(t *testing.T) { ss := []consent.ForcedObfuscatedLoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 13, len(ss)) for _, s := range ss { @@ -169,7 +169,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) { lrs := []flow.LogoutRequest{} - c.All(&lrs) + require.NoError(t, c.All(&lrs)) require.Equal(t, 7, len(lrs)) for _, s := range lrs { @@ -182,7 +182,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_jti_blacklist", func(t *testing.T) { bjtis := []oauth2.BlacklistedJTI{} - c.All(&bjtis) + require.NoError(t, c.All(&bjtis)) require.Equal(t, 1, len(bjtis)) for _, bjti := range bjtis { testhelpersuuid.AssertUUID(t, bjti.NID) @@ -194,7 +194,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_access", func(t *testing.T) { as := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as)) require.Equal(t, 13, len(as)) for _, a := range as { @@ -210,7 +210,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_refresh", func(t *testing.T) { rs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_refresh").All(&rs) + require.NoError(t, c.RawQuery(`SELECT signature, nid, request_id, challenge_id, requested_at, client_id, scope, granted_scope, requested_audience, granted_audience, form_data, subject, active, session_data, expires_at FROM hydra_oauth2_refresh`).All(&rs)) require.Equal(t, 13, len(rs)) for _, r := range rs { @@ -226,7 +226,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_code", func(t *testing.T) { cs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs)) require.Equal(t, 13, len(cs)) for _, c := range cs { @@ -242,7 +242,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_oidc", func(t *testing.T) { os := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os)) require.Equal(t, 13, len(os)) for _, o := range os { @@ -258,7 +258,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_pkce", func(t *testing.T) { ps := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps)) require.Equal(t, 11, len(ps)) for _, p := range ps { @@ -274,7 +274,7 @@ func TestMigrations(t *testing.T) { t.Run("case=networks", func(t *testing.T) { ns := []networkx.Network{} - c.RawQuery("SELECT * FROM networks").All(&ns) + require.NoError(t, c.RawQuery("SELECT * FROM networks").All(&ns)) require.Equal(t, 1, len(ns)) for _, n := range ns { testhelpersuuid.AssertUUID(t, n.ID) diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql new file mode 100644 index 00000000000..a30a127e902 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN first_used_at; diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql new file mode 100644 index 00000000000..8ae823047f7 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD first_used_at TIMESTAMP DEFAULT NULL; diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 93649fc46ef..98161c55cf6 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -60,6 +60,7 @@ type ( contextx.Provider x.RegistryLogger x.TracingProvider + config.Provider } ) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 6e1336b80de..083e67ac5da 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -58,6 +58,10 @@ type ( // InternalExpiresAt denormalizes the expiry from the session to additionally store it as a row. InternalExpiresAt sqlxx.NullTime `db:"expires_at" json:"-"` } + OAuth2RefreshTable struct { + OAuth2RequestSQL + FirstUsedAt sql.NullTime `db:"first_used_at"` + } ) const ( @@ -72,6 +76,10 @@ func (r OAuth2RequestSQL) TableName() string { return "hydra_oauth2_" + string(r.Table) } +func (r OAuth2RefreshTable) TableName() string { + return "hydra_oauth2_refresh" +} + func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName, expiresAt time.Time) (*OAuth2RequestSQL, error) { subject := "" if r.GetSession() == nil { @@ -122,6 +130,24 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, }, nil } +func (p *Persister) marshalSession(ctx context.Context, session fosite.Session) ([]byte, error) { + sessionBytes, err := json.Marshal(session) + if err != nil { + return nil, err + } + + if !p.config.EncryptSessionData(ctx) { + return sessionBytes, nil + } + + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, sessionBytes, nil) + if err != nil { + return nil, err + } + + return []byte(ciphertext), nil +} + func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.toRequest") defer otelx.End(span, &err) @@ -429,7 +455,34 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableRefresh) + + r := OAuth2RefreshTable{OAuth2RequestSQL: OAuth2RequestSQL{Table: sqlTableRefresh}} + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } + + fositeRequest, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + + if r.Active { + return fositeRequest, nil + } + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid { + if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) { + return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) + } + + r.Active = true // We set active to true because we are in the grace period. + return r.toRequest(ctx, session, p) // And re-generate the request + } + + return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) } func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { @@ -486,7 +539,17 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + id, + p.NetworkID(ctx), + ). + Exec(), + ) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { diff --git a/spec/config.json b/spec/config.json index 2445cbc6a24..72f81534c66 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1068,6 +1068,21 @@ "type": "object", "additionalProperties": false, "properties": { + "refresh_token": { + "type": "object", + "properties": { + "grace_period": { + "title": "Refresh Token Rotation Grace Period", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "default": "0s", + "allOf": [ + { + "$ref": "#/definitions/duration" + } + ] + } + } + }, "jwt": { "type": "object", "additionalProperties": false, @@ -1122,8 +1137,8 @@ } ] } - } - }, + } + }, "secrets": { "type": "object", "additionalProperties": false, diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 23654c519b9..546cfc98870 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -18,16 +18,13 @@ import ( type FositeStorer interface { fosite.Storage oauth2.CoreStorage + oauth2.TokenRevocationStorage openid.OpenIDConnectRequestStorage pkce.PKCERequestStorage rfc7523.RFC7523KeyStorage verifiable.NonceManager oauth2.ResourceOwnerPasswordCredentialsGrantStorage - RevokeRefreshToken(ctx context.Context, requestID string) error - - RevokeAccessToken(ctx context.Context, requestID string) error - // flush the access token requests from the database. // no data will be deleted after the 'notAfter' timeframe. FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error