From fb74bf89d4425314c3261b12fbc5fa523e1e0e57 Mon Sep 17 00:00:00 2001 From: Krzysztof Bogacki Date: Wed, 2 Aug 2023 15:32:19 +0200 Subject: [PATCH] feat: redirect to OIDC providers only once in registration flows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test(e2e): ensure there is only one OIDC redirect Co-authored-by: Jakub FijaƂkowski --- .../strategy/oidc/strategy_registration.go | 38 +++++++++++++++ .../oidc/registration/success.spec.ts | 46 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index f756e8caf156..0b71061ef4a0 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -42,6 +42,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{ type MetadataType string +type OIDCProviderData struct { + Provider string `json:"provider"` + Tokens *identity.CredentialsOIDCEncryptedTokens `json:"tokens"` + Claims Claims `json:"claims"` +} + type VerifiedAddress struct { Value string `json:"value"` Via identity.VerifiableAddressType `json:"via"` @@ -52,6 +58,8 @@ const ( PublicMetadata MetadataType = "identity.metadata_public" AdminMetadata MetadataType = "identity.metadata_admin" + + InternalContextKeyProviderData = "provider_data" ) func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) { @@ -213,6 +221,25 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return errors.WithStack(flow.ErrCompletedByStrategy) } + if oidcProviderData := gjson.GetBytes(f.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)); oidcProviderData.IsObject() { + var providerData OIDCProviderData + if err := json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil { + return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %s", err))) + } + if pid != providerData.Provider { + return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider))) + } + _, err = s.processRegistration(ctx, w, r, f, providerData.Tokens, &providerData.Claims, provider, &AuthCodeContainer{ + FlowID: f.ID.String(), + Traits: p.Traits, + TransientPayload: f.TransientPayload, + }, "") + if err != nil { + return s.handleError(ctx, w, r, f, pid, nil, err) + } + return errors.WithStack(flow.ErrCompletedByStrategy) + } + state := generateState(f.ID.String()) if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode { state.setCode(code.InitCode) @@ -309,6 +336,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, nil } + providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData) + if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData { + if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Tokens: token, Claims: *claims}); err == nil { + rf.InternalContext = internalContext + } + } + fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context())), fetcher.WithCache(jsonnetCache, 60*time.Minute)) jsonnetMapperSnippet, err := fetch.FetchContext(r.Context(), provider.Config().Mapper) if err != nil { @@ -347,6 +381,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err) } + if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil { + rf.InternalContext = internalContext + } + return nil, nil } diff --git a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts index 370d5f0ff253..f0f9ec67159c 100644 --- a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts @@ -103,6 +103,52 @@ context("Social Sign Up Successes", () => { }) }) + it("should redirect to oidc provider only once", () => { + const email = gen.email() + + cy.registerOidc({ + app, + email, + expectSession: false, + route: registration, + }) + + cy.get(appPrefix(app) + '[name="traits.email"]').should( + "have.value", + email, + ) + + cy.get('[name="traits.consent"][type="checkbox"]') + .siblings("label") + .click() + cy.get('[name="traits.newsletter"][type="checkbox"]') + .siblings("label") + .click() + cy.get('[name="traits.website"]').type(website) + + cy.intercept("GET", "http://*/oauth2/auth*", { + forceNetworkError: true, + }).as("additionalRedirect") + + cy.triggerOidc(app) + + cy.get("@additionalRedirect").should("not.exist") + + cy.location("pathname").should((loc) => { + expect(loc).to.be.oneOf([ + "/welcome", + "/", + "/sessions", + "/verification", + ]) + }) + + cy.getSession().should((session) => { + shouldSession(email)(session) + expect(session.identity.traits.consent).to.equal(true) + }) + }) + it("should pass transient_payload to webhook", () => { testFlowWebhook( (hooks) =>