diff --git a/driver/config/provider.go b/driver/config/provider.go index a6abecaad0e..4f66e4448aa 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -102,6 +102,7 @@ const ( KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim" KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims" KeyMirrorTopLevelClaims = "oauth2.mirror_top_level_claims" + KeyRefreshTokenRotationGracePeriod = "oauth2.grant.refresh_token.rotation_grace_period" // #nosec G101 KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" @@ -669,3 +670,11 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { return p.getProvider(ctx).String(key) + suffix } + +func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { + gracePeriod := p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) + if gracePeriod > time.Hour { + return time.Hour + } + return gracePeriod +} diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index 8e5c44a9e2e..168ca81d69f 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -291,6 +291,13 @@ func TestViperProviderValidates(t *testing.T) { assert.Equal(t, "random_salt", c.SubjectIdentifierAlgorithmSalt(ctx)) assert.Equal(t, []string{"whatever"}, c.DefaultClientScope(ctx)) + // refresh + assert.Equal(t, time.Duration(0), c.RefreshTokenRotationGracePeriod(ctx)) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s")) + assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod(ctx)) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h")) + assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod(ctx)) + // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx)) assert.Equal(t, urlx.ParseOrPanic("https://public/"), c.PublicURL(ctx)) diff --git a/internal/config/config.yaml b/internal/config/config.yaml index f3e8bff399c..49615d95966 100644 --- a/internal/config/config.yaml +++ b/internal/config/config.yaml @@ -402,6 +402,18 @@ oauth2: session: # store encrypted data in database, default true encrypt_at_rest: true + ## refresh_token_rotation + # By default Refresh Tokens are rotated and invalidated with each use. See https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#section-4.13.2 for more details + refresh_token_rotation: + # + ## grace_period + # + # Set the grace period for refresh tokens to be reused. Such reused tokens will result in multiple refresh tokens being issued. + # + # Examples: + # - 5s + # - 1m + grace_period: 0s # The secrets section configures secrets used for encryption and signing of several systems. All secrets can be rotated, # for more information on this topic navigate to: diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index 6b64d61991e..553a6bae62b 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -25,6 +25,7 @@ import ( "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/x" "github.com/ory/fosite/storage" @@ -225,16 +226,18 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) + t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) } func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { return func(t *testing.T) { - requestId := uuid.New() - mockRequestForeignKey(t, requestId, m) + ctx := context.Background() + requestID := uuid.New() + mockRequestForeignKey(t, requestID, m) cl := &client.Client{ID: "foobar"} fositeRequest := &fosite.Request{ - ID: requestId, + ID: requestID, Client: cl, RequestedAt: time.Now().UTC().Round(time.Second), Session: NewSession("bar"), @@ -242,15 +245,15 @@ func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing. for i := 0; i < 4; i++ { signature := uuid.New() - err := m.OAuth2Storage().CreateRefreshTokenSession(context.TODO(), signature, fositeRequest) + err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreateAccessTokenSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreateAccessTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreatePKCERequestSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreatePKCERequestSession(ctx, signature, fositeRequest) assert.NoError(t, err) - err = m.OAuth2Storage().CreateAuthorizeCodeSession(context.TODO(), signature, fositeRequest) + err = m.OAuth2Storage().CreateAuthorizeCodeSession(ctx, signature, fositeRequest) assert.NoError(t, err) } } @@ -475,7 +478,7 @@ func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) { m := x.OAuth2Storage() c := &client.Client{ID: "nil-request-client-id-123"} require.NoError(t, x.ClientManager().CreateClient(context.Background(), c)) - err := m.CreateAccessTokenSession(context.TODO(), "nil-request-id", &fosite.Request{ + err := m.CreateAccessTokenSession(context.Background(), "nil-request-id", &fosite.Request{ ID: "", RequestedAt: time.Now().UTC().Round(time.Second), Client: c, @@ -553,6 +556,63 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { } } +func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { + return func(t *testing.T) { + ctx := context.Background() + + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { + // SETUP + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + require.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + require.NoError(t, err) + + tmpSession := new(fosite.Session) + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + + // ASSERT + // a revoked refresh token returns an error when getting the token again + assert.ErrorIs(t, err, fosite.ErrInactiveToken) + }) + + t.Run("refresh token enters grace period when configured,", func(t *testing.T) { + // SETUP + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1m") + + // always reset back to the default + t.Cleanup(func() { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "0m") + }) + + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + require.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + + req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + + // ASSERT + // when grace period is configured the refresh token can be obtained within + // the grace period without error + assert.NoError(t, err) + + assert.Equal(t, defaultRequest.GetID(), req.GetID()) + }) + } + +} + func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() @@ -880,6 +940,7 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { + ctx := context.Background() grantManager := x.GrantManager() keyManager := x.KeyManager() grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage) @@ -902,28 +963,28 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - storedKeySet, err := grantStorage.GetPublicKeys(context.TODO(), issuer, subject) + storedKeySet, err := grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) require.Len(t, storedKeySet.Keys, 0) - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - storedKeySet, err = grantStorage.GetPublicKeys(context.TODO(), issuer, subject) + storedKeySet, err = grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) assert.Len(t, storedKeySet.Keys, 1) - storedKey, err := grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + storedKey, err := grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, publicKey.KeyID, storedKey.KeyID) assert.Equal(t, publicKey.Use, storedKey.Use) assert.Equal(t, publicKey.Key, storedKey.Key) - storedScopes, err := grantStorage.GetPublicKeyScopes(context.TODO(), issuer, subject, publicKey.KeyID) + storedScopes, err := grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, grant.Scope, storedScopes) - storedKeySet, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + storedKeySet, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, publicKey.KeyID, storedKeySet.Keys[0].KeyID) assert.Equal(t, publicKey.Use, storedKeySet.Keys[0].Use) @@ -953,7 +1014,7 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-2", "sig") require.NoError(t, err) - require.NoError(t, grantManager.CreateGrant(context.TODO(), trust.Grant{ + require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{ ID: uuid.New(), Issuer: issuer, Subject: subject, @@ -1011,22 +1072,22 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, grant.PublicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, grant.PublicKey.KeyID) require.NoError(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) - err = grantManager.DeleteGrant(context.TODO(), grant.ID) + err = grantManager.DeleteGrant(ctx, grant.ID) require.NoError(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) assert.Error(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) assert.Error(t, err) }) @@ -1048,22 +1109,22 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) - err = keyManager.DeleteKey(context.TODO(), issuer, publicKey.KeyID) + err = keyManager.DeleteKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) - _, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID) + _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) assert.Error(t, err) - _, err = grantManager.GetConcreteGrant(context.TODO(), grant.ID) + _, err = grantManager.GetConcreteGrant(ctx, grant.ID) assert.Error(t, err) }) @@ -1085,25 +1146,25 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) // All three get methods should only return the public key when using the valid subject - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, "any-subject-1", publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID) require.Error(t, err) - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) - _, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "any-subject-2", publicKey.KeyID) + _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID) require.Error(t, err) - _, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, subject, publicKey.KeyID) + _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) - jwks, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3") + jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) require.NotNil(t, jwks) require.Empty(t, jwks.Keys) - jwks, err = grantStorage.GetPublicKeys(context.TODO(), issuer, subject) + jwks, err = grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) require.NotNil(t, jwks) require.NotEmpty(t, jwks.Keys) @@ -1126,17 +1187,17 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) // All three get methods should always return the public key - _, err = grantStorage.GetPublicKey(context.TODO(), issuer, "any-subject-1", publicKey.KeyID) + _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID) require.NoError(t, err) - _, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "any-subject-2", publicKey.KeyID) + _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID) require.NoError(t, err) - jwks, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3") + jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) require.NotNil(t, jwks) require.NotEmpty(t, jwks.Keys) @@ -1159,10 +1220,10 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(-1, 0, 0), } - err = grantManager.CreateGrant(context.TODO(), grant, publicKey) + err = grantManager.CreateGrant(ctx, grant, publicKey) require.NoError(t, err) - keys, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3") + keys, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) assert.Len(t, keys.Keys, 0) }) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index cb12f68c1ab..0d89e14ac9b 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -176,8 +176,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { } assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { - actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64) - require.NoError(t, err) + introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) + actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) + require.NoError(t, err, "%s", introspect) requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) } @@ -332,6 +333,186 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) + t.Run("case=graceful token rotation", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + t.Cleanup(func() { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + }) + + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, nil), + acceptConsentHandler(t, c, subject, nil), + ) + + issueTokens := func(t *testing.T) *oauth2.Token { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) + + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return token + } + + refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat := time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + + introspectAccessToken(t, conf, refreshedToken, subject) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return refreshedToken + } + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + first = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + second = refreshTokens(t, token) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + first = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + second = refreshTokens(t, token) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + _, err = conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { + start := time.Now() + token := issueTokens(t) + var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + a1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + b1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=first refresh from first refresh", func(t *testing.T) { + a2RefreshA = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=second refresh from first refresh", func(t *testing.T) { + a2RefreshB = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=first refresh from second refresh", func(t *testing.T) { + b2RefreshA = refreshTokens(t, b1Refresh) + }) + + t.Run("followup=second refresh from second refresh", func(t *testing.T) { + b2RefreshB = refreshTokens(t, b1Refresh) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, token := range []*oauth2.Token{ + a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + } + }) + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) + t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { // Make sure we test against all crypto suites that we advertise. cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 1fa0ce3836d..8564cfab969 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -144,7 +144,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) { ss := []flow.LoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 17, len(ss)) for _, s := range ss { @@ -157,7 +157,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_obfuscated_authentication_session", func(t *testing.T) { ss := []consent.ForcedObfuscatedLoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 13, len(ss)) for _, s := range ss { @@ -169,7 +169,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) { lrs := []flow.LogoutRequest{} - c.All(&lrs) + require.NoError(t, c.All(&lrs)) require.Equal(t, 7, len(lrs)) for _, s := range lrs { @@ -182,7 +182,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_jti_blacklist", func(t *testing.T) { bjtis := []oauth2.BlacklistedJTI{} - c.All(&bjtis) + require.NoError(t, c.All(&bjtis)) require.Equal(t, 1, len(bjtis)) for _, bjti := range bjtis { testhelpersuuid.AssertUUID(t, bjti.NID) @@ -194,7 +194,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_access", func(t *testing.T) { as := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as)) require.Equal(t, 13, len(as)) for _, a := range as { @@ -210,7 +210,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_refresh", func(t *testing.T) { rs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_refresh").All(&rs) + require.NoError(t, c.RawQuery(`SELECT signature, nid, request_id, challenge_id, requested_at, client_id, scope, granted_scope, requested_audience, granted_audience, form_data, subject, active, session_data, expires_at FROM hydra_oauth2_refresh`).All(&rs)) require.Equal(t, 13, len(rs)) for _, r := range rs { @@ -226,7 +226,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_code", func(t *testing.T) { cs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs)) require.Equal(t, 13, len(cs)) for _, c := range cs { @@ -242,7 +242,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_oidc", func(t *testing.T) { os := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os)) require.Equal(t, 13, len(os)) for _, o := range os { @@ -258,7 +258,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_pkce", func(t *testing.T) { ps := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps)) require.Equal(t, 11, len(ps)) for _, p := range ps { @@ -274,7 +274,7 @@ func TestMigrations(t *testing.T) { t.Run("case=networks", func(t *testing.T) { ns := []networkx.Network{} - c.RawQuery("SELECT * FROM networks").All(&ns) + require.NoError(t, c.RawQuery("SELECT * FROM networks").All(&ns)) require.Equal(t, 1, len(ns)) for _, n := range ns { testhelpersuuid.AssertUUID(t, n.ID) diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql new file mode 100644 index 00000000000..a30a127e902 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN first_used_at; diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql new file mode 100644 index 00000000000..8ae823047f7 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD first_used_at TIMESTAMP DEFAULT NULL; diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 93649fc46ef..98161c55cf6 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -60,6 +60,7 @@ type ( contextx.Provider x.RegistryLogger x.TracingProvider + config.Provider } ) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 6e1336b80de..083e67ac5da 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -58,6 +58,10 @@ type ( // InternalExpiresAt denormalizes the expiry from the session to additionally store it as a row. InternalExpiresAt sqlxx.NullTime `db:"expires_at" json:"-"` } + OAuth2RefreshTable struct { + OAuth2RequestSQL + FirstUsedAt sql.NullTime `db:"first_used_at"` + } ) const ( @@ -72,6 +76,10 @@ func (r OAuth2RequestSQL) TableName() string { return "hydra_oauth2_" + string(r.Table) } +func (r OAuth2RefreshTable) TableName() string { + return "hydra_oauth2_refresh" +} + func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName, expiresAt time.Time) (*OAuth2RequestSQL, error) { subject := "" if r.GetSession() == nil { @@ -122,6 +130,24 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, }, nil } +func (p *Persister) marshalSession(ctx context.Context, session fosite.Session) ([]byte, error) { + sessionBytes, err := json.Marshal(session) + if err != nil { + return nil, err + } + + if !p.config.EncryptSessionData(ctx) { + return sessionBytes, nil + } + + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, sessionBytes, nil) + if err != nil { + return nil, err + } + + return []byte(ciphertext), nil +} + func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.toRequest") defer otelx.End(span, &err) @@ -429,7 +455,34 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableRefresh) + + r := OAuth2RefreshTable{OAuth2RequestSQL: OAuth2RequestSQL{Table: sqlTableRefresh}} + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } + + fositeRequest, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + + if r.Active { + return fositeRequest, nil + } + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid { + if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) { + return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) + } + + r.Active = true // We set active to true because we are in the grace period. + return r.toRequest(ctx, session, p) // And re-generate the request + } + + return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) } func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { @@ -486,7 +539,17 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + id, + p.NetworkID(ctx), + ). + Exec(), + ) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { diff --git a/spec/config.json b/spec/config.json index 2445cbc6a24..72f81534c66 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1068,6 +1068,21 @@ "type": "object", "additionalProperties": false, "properties": { + "refresh_token": { + "type": "object", + "properties": { + "grace_period": { + "title": "Refresh Token Rotation Grace Period", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "default": "0s", + "allOf": [ + { + "$ref": "#/definitions/duration" + } + ] + } + } + }, "jwt": { "type": "object", "additionalProperties": false, @@ -1122,8 +1137,8 @@ } ] } - } - }, + } + }, "secrets": { "type": "object", "additionalProperties": false, diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 23654c519b9..546cfc98870 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -18,16 +18,13 @@ import ( type FositeStorer interface { fosite.Storage oauth2.CoreStorage + oauth2.TokenRevocationStorage openid.OpenIDConnectRequestStorage pkce.PKCERequestStorage rfc7523.RFC7523KeyStorage verifiable.NonceManager oauth2.ResourceOwnerPasswordCredentialsGrantStorage - RevokeRefreshToken(ctx context.Context, requestID string) error - - RevokeAccessToken(ctx context.Context, requestID string) error - // flush the access token requests from the database. // no data will be deleted after the 'notAfter' timeframe. FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error