Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: propagate logout to identity provider #3596

Merged
merged 7 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type (
// Cookie management
GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error)
CreateLoginSession(ctx context.Context, session *flow.LoginSession) error
DeleteLoginSession(ctx context.Context, id string) error
DeleteLoginSession(ctx context.Context, id string) (deletedSession *flow.LoginSession, err error)
RevokeSubjectLoginSession(ctx context.Context, user string) error
ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authTime time.Time, subject string, remember bool) error

Expand Down
14 changes: 10 additions & 4 deletions consent/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,12 @@ func TestHelperNID(r interface {
require.NoError(t, err)
require.Error(t, t2InvalidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
require.NoError(t, t1ValidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
require.Error(t, t2InvalidNID.DeleteLoginSession(ctx, testLS.ID))
require.NoError(t, t1ValidNID.DeleteLoginSession(ctx, testLS.ID))
ls, err := t2InvalidNID.DeleteLoginSession(ctx, testLS.ID)
require.Error(t, err)
assert.Nil(t, ls)
ls, err = t1ValidNID.DeleteLoginSession(ctx, testLS.ID)
require.NoError(t, err)
assert.Equal(t, testLS.ID, ls.ID)
}
}

Expand Down Expand Up @@ -429,8 +433,9 @@ func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeMana
},
} {
t.Run("case=delete-get-"+tc.id, func(t *testing.T) {
err := m.DeleteLoginSession(ctx, tc.id)
ls, err := m.DeleteLoginSession(ctx, tc.id)
require.NoError(t, err)
assert.EqualValues(t, tc.id, ls.ID)

_, err = m.GetRememberedLoginSession(ctx, nil, tc.id)
require.Error(t, err)
Expand Down Expand Up @@ -1083,7 +1088,8 @@ func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeMana
require.NoError(t, err)
assert.EqualValues(t, expected.ID, result.ID)

require.NoError(t, m.DeleteLoginSession(ctx, s.ID))
_, err = m.DeleteLoginSession(ctx, s.ID)
require.NoError(t, err)

result, err = m.GetConsentRequest(ctx, expected.ID)
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions consent/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/ory/fosite/handler/openid"
"github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/x"
)

Expand All @@ -17,6 +18,7 @@ type InternalRegistry interface {
x.RegistryCookieStore
x.RegistryLogger
x.HTTPClientProvider
kratos.Provider
Registry
client.Registry

Expand Down
25 changes: 18 additions & 7 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ func (s *DefaultStrategy) revokeAuthenticationSession(ctx context.Context, w htt
return nil
}

return s.r.ConsentManager().DeleteLoginSession(r.Context(), sid)
_, err = s.r.ConsentManager().DeleteLoginSession(r.Context(), sid)

return err
}

func (s *DefaultStrategy) revokeAuthenticationCookie(w http.ResponseWriter, r *http.Request, ss sessions.Store) (string, error) {
Expand Down Expand Up @@ -458,6 +460,7 @@ func (s *DefaultStrategy) verifyAuthentication(
return nil, fosite.ErrAccessDenied.WithHint("The login session cookie was not found or malformed.")
}

loginSession.IdentityProviderSessionID = f.IdentityProviderSessionID
if err := s.r.ConsentManager().ConfirmLoginSession(ctx, loginSession, sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil {
return nil, err
}
Expand Down Expand Up @@ -731,7 +734,8 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
return urls, nil
}

func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error {
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
ctx := r.Context()
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
if err != nil {
return err
Expand Down Expand Up @@ -1000,8 +1004,9 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
}

func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Context, r *http.Request, subject string, sid string) error {
if err := s.executeBackChannelLogout(r.Context(), r, subject, sid); err != nil {
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
ctx := r.Context()
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
return err
}

Expand All @@ -1010,10 +1015,16 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Con
//
// executeBackChannelLogout only fails on system errors so not on URL errors, so this should be fine
// even if an upstream URL fails!
if err := s.r.ConsentManager().DeleteLoginSession(r.Context(), sid); errors.Is(err, sqlcon.ErrNoRows) {
if session, err := s.r.ConsentManager().DeleteLoginSession(ctx, sid); errors.Is(err, sqlcon.ErrNoRows) {
// This is ok (session probably already revoked), do nothing!
} else if err != nil {
return err
} else {
innerErr := s.r.Kratos().DisableSession(ctx, session.IdentityProviderSessionID.String())
if innerErr != nil {
s.r.Logger().WithError(innerErr).WithField("sid", sid).Error("Unable to revoke session in ORY Kratos.")
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
}
// We don't return the error here because we don't want to break the logout flow if Kratos is down.
}

return nil
Expand Down Expand Up @@ -1068,7 +1079,7 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
return nil, err
}

if err := s.performBackChannelLogoutAndDeleteSession(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1105,7 +1116,7 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
return lsErr
}

if err := s.performBackChannelLogoutAndDeleteSession(r.Context(), r, loginSession.Subject, sid); err != nil {
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
return err
}

Expand Down
14 changes: 5 additions & 9 deletions consent/strategy_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,21 @@ import (
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"testing"

hydra "github.com/ory/hydra-client-go/v2"

"github.com/stretchr/testify/require"

"github.com/ory/fosite/token/jwt"
"github.com/ory/x/urlx"

"net/url"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/fosite/token/jwt"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/client"
. "github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/internal/testhelpers"
"github.com/ory/x/ioutilx"
"github.com/ory/x/urlx"
)

func checkAndAcceptLoginHandler(t *testing.T, apiClient *hydra.APIClient, subject string, cb func(*testing.T, *hydra.OAuth2LoginRequest, error) hydra.AcceptOAuth2LoginRequest) http.HandlerFunc {
Expand Down
22 changes: 21 additions & 1 deletion consent/strategy_logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/x/pointerx"

"github.com/stretchr/testify/assert"
Expand All @@ -35,9 +36,11 @@ import (

func TestLogoutFlows(t *testing.T) {
ctx := context.Background()
fakeKratos := kratos.NewFake()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour)
reg.WithKratos(fakeKratos)

defaultRedirectedMessage := "redirected to default server"
postLogoutCallback := func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -181,7 +184,10 @@ func TestLogoutFlows(t *testing.T) {
checkAndAcceptLoginHandler(t, adminApi, subject, func(t *testing.T, res *hydra.OAuth2LoginRequest, err error) hydra.AcceptOAuth2LoginRequest {
require.NoError(t, err)
//res.Payload.SessionID
return hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Bool(true)}
return hydra.AcceptOAuth2LoginRequest{
Remember: pointerx.Ptr(true),
IdentityProviderSessionId: pointerx.Ptr(kratos.FakeSessionID),
}
}),
checkAndAcceptConsentHandler(t, adminApi, func(t *testing.T, res *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
require.NoError(t, err)
Expand Down Expand Up @@ -476,6 +482,7 @@ func TestLogoutFlows(t *testing.T) {
})

t.Run("case=should return to default post logout because session was revoked in browser context", func(t *testing.T) {
fakeKratos.Reset()
c := createSampleClient(t)
sid := make(chan string)
acceptLoginAsAndWatchSid(t, subject, sid)
Expand Down Expand Up @@ -518,9 +525,13 @@ func TestLogoutFlows(t *testing.T) {
assert.NotEmpty(t, res.Request.URL.Query().Get("code"))

wg.Wait()

assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should execute backchannel logout in headless flow with sid", func(t *testing.T) {
fakeKratos.Reset()
numSidConsumers := 2
sid := make(chan string, numSidConsumers)
acceptLoginAsAndWatchSidForConsumers(t, subject, sid, true, numSidConsumers)
Expand All @@ -535,22 +546,31 @@ func TestLogoutFlows(t *testing.T) {
logoutViaHeadlessAndExpectNoContent(t, createBrowserWithSession(t, c), url.Values{"sid": {<-sid}})

backChannelWG.Wait() // we want to ensure that all back channels have been called!
assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should logout in headless flow with non-existing sid", func(t *testing.T) {
fakeKratos.Reset()
logoutViaHeadlessAndExpectNoContent(t, browserWithoutSession, url.Values{"sid": {"non-existing-sid"}})
assert.False(t, fakeKratos.DisableSessionWasCalled)
})

t.Run("case=should logout in headless flow with session that has remember=false", func(t *testing.T) {
fakeKratos.Reset()
sid := make(chan string)
acceptLoginAsAndWatchSidForConsumers(t, subject, sid, false, 1)

c := createSampleClient(t)

logoutViaHeadlessAndExpectNoContent(t, createBrowserWithSession(t, c), url.Values{"sid": {<-sid}})
assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should fail headless logout because neither sid nor subject were provided", func(t *testing.T) {
fakeKratos.Reset()
logoutViaHeadlessAndExpectError(t, browserWithoutSession, url.Values{}, `Either 'subject' or 'sid' query parameters need to be defined.`)
assert.False(t, fakeKratos.DisableSessionWasCalled)
})
}
13 changes: 11 additions & 2 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ const (
KeyPublicURL = "urls.self.public"
KeyAdminURL = "urls.self.admin"
KeyIssuerURL = "urls.self.issuer"
KeyIdentityProviderAdminURL = "urls.identity_provider.admin_base_url"
KeyAccessTokenStrategy = "strategies.access_token"
KeyJWTScopeClaimStrategy = "strategies.jwt.scope_claim"
KeyDBIgnoreUnknownTableColumns = "db.ignore_unknown_table_columns"
Expand All @@ -104,8 +105,10 @@ const (

const DSNMemory = "memory"

var _ hasherx.PBKDF2Configurator = (*DefaultProvider)(nil)
var _ hasherx.BCryptConfigurator = (*DefaultProvider)(nil)
var (
_ hasherx.PBKDF2Configurator = (*DefaultProvider)(nil)
_ hasherx.BCryptConfigurator = (*DefaultProvider)(nil)
)

type DefaultProvider struct {
l *logrusx.Logger
Expand Down Expand Up @@ -393,6 +396,12 @@ func (p *DefaultProvider) IssuerURL(ctx context.Context) *url.URL {
)
}

func (p *DefaultProvider) KratosAdminURL(ctx context.Context) (*url.URL, bool) {
u := p.getProvider(ctx).RequestURIF(KeyIdentityProviderAdminURL, nil)

return u, u != nil
}

func (p *DefaultProvider) OAuth2ClientRegistrationURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyOAuth2ClientRegistrationURL, new(url.URL))
}
Expand Down
4 changes: 4 additions & 0 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"go.opentelemetry.io/otel/trace"

"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/x/httprouterx"
"github.com/ory/x/popx"

Expand Down Expand Up @@ -54,6 +55,7 @@ type Registry interface {
WithLogger(l *logrusx.Logger) Registry
WithTracer(t trace.Tracer) Registry
WithTracerWrapper(TracerWrapper) Registry
WithKratos(k kratos.Client) Registry
x.HTTPClientProvider
GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy

Expand All @@ -72,6 +74,8 @@ type Registry interface {
x.TracingProvider
FlowCipher() *aead.XChaCha20Poly1305

kratos.Provider

RegisterRoutes(ctx context.Context, admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic)
ClientHandler() *client.Handler
KeyHandler() *jwk.Handler
Expand Down
14 changes: 14 additions & 0 deletions driver/registry_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/fositex"
"github.com/ory/hydra/v2/hsm"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/oauth2/trust"
Expand Down Expand Up @@ -88,6 +89,7 @@ type RegistryBase struct {
hmacs *foauth2.HMACSHAStrategy
fc *fositex.Config
publicCORS *cors.Cors
kratos kratos.Client
}

func (m *RegistryBase) GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy {
Expand Down Expand Up @@ -201,6 +203,11 @@ func (m *RegistryBase) WithTracerWrapper(wrapper TracerWrapper) Registry {
return m.r
}

func (m *RegistryBase) WithKratos(k kratos.Client) Registry {
m.kratos = k
return m.r
}

func (m *RegistryBase) Logger() *logrusx.Logger {
if m.l == nil {
m.l = logrusx.New("Ory Hydra", m.BuildVersion())
Expand Down Expand Up @@ -552,3 +559,10 @@ func (m *RegistryBase) HSMContext() hsm.Context {
func (m *RegistrySQL) ClientAuthenticator() x.ClientAuthenticator {
return m.OAuth2Provider().(*fosite.Fosite)
}

func (m *RegistryBase) Kratos() kratos.Client {
if m.kratos == nil {
m.kratos = kratos.New(m)
}
return m.kratos
}
17 changes: 12 additions & 5 deletions flow/consent_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ type OAuth2RedirectTo struct {

// swagger:ignore
type LoginSession struct {
ID string `db:"id"`
NID uuid.UUID `db:"nid"`
AuthenticatedAt sqlxx.NullTime `db:"authenticated_at"`
Subject string `db:"subject"`
Remember bool `db:"remember"`
ID string `db:"id"`
NID uuid.UUID `db:"nid"`
AuthenticatedAt sqlxx.NullTime `db:"authenticated_at"`
Subject string `db:"subject"`
IdentityProviderSessionID sqlxx.NullString `db:"identity_provider_session_id"`
Remember bool `db:"remember"`
}

func (LoginSession) TableName() string {
Expand Down Expand Up @@ -292,6 +293,12 @@ type HandledLoginRequest struct {
// required: true
Subject string `json:"subject"`

// IdentityProviderSessionID is the session ID of the end-user that authenticated.
// If specified, we will use this value to propagate the logout.
//
// required: false
IdentityProviderSessionID string `json:"identity_provider_session_id,omitempty"`

// ForceSubjectIdentifier forces the "pairwise" user ID of the end-user that authenticated. The "pairwise" user ID refers to the
// (Pairwise Identifier Algorithm)[http://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg] of the OpenID
// Connect specification. It allows you to set an obfuscated subject ("user") identifier that is unique to the client.
Expand Down
Loading