diff --git a/internal/testhelpers/selfservice_verification.go b/internal/testhelpers/selfservice_verification.go index 0071b52aa9ec..1cfca6fb9b5d 100644 --- a/internal/testhelpers/selfservice_verification.go +++ b/internal/testhelpers/selfservice_verification.go @@ -49,10 +49,14 @@ func GetRecoveryFlow(t *testing.T, client *http.Client, ts *httptest.Server) *kr return rs } -func InitializeRecoveryFlowViaBrowser(t *testing.T, client *http.Client, isSPA bool, ts *httptest.Server) *kratos.SelfServiceRecoveryFlow { +func InitializeRecoveryFlowViaBrowser(t *testing.T, client *http.Client, isSPA bool, ts *httptest.Server, values url.Values) *kratos.SelfServiceRecoveryFlow { publicClient := NewSDKCustomClient(ts, client) - req, err := http.NewRequest("GET", ts.URL+recovery.RouteInitBrowserFlow, nil) + u := ts.URL + recovery.RouteInitBrowserFlow + if values != nil { + u += "?" + values.Encode() + } + req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) if isSPA { @@ -120,7 +124,7 @@ func SubmitRecoveryForm( if isAPI { f = InitializeRecoveryFlowViaAPI(t, hc, publicTS) } else { - f = InitializeRecoveryFlowViaBrowser(t, hc, isSPA, publicTS) + f = InitializeRecoveryFlowViaBrowser(t, hc, isSPA, publicTS, nil) } time.Sleep(time.Millisecond) // add a bit of delay to allow `1ns` to time out. diff --git a/selfservice/flow/recovery/handler.go b/selfservice/flow/recovery/handler.go index de1032d1ba60..6cbbe53c80c8 100644 --- a/selfservice/flow/recovery/handler.go +++ b/selfservice/flow/recovery/handler.go @@ -8,9 +8,10 @@ import ( "github.com/ory/kratos/schema" - "github.com/ory/kratos/ui/node" "github.com/ory/x/sqlcon" + "github.com/ory/kratos/ui/node" + "github.com/ory/herodot" "github.com/julienschmidt/httprouter" diff --git a/selfservice/strategy/link/strategy_recovery.go b/selfservice/strategy/link/strategy_recovery.go index 343070d403bb..f933997dea60 100644 --- a/selfservice/strategy/link/strategy_recovery.go +++ b/selfservice/strategy/link/strategy_recovery.go @@ -10,6 +10,11 @@ import ( "github.com/pkg/errors" "github.com/ory/herodot" + "github.com/ory/x/decoderx" + "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" + "github.com/ory/x/urlx" + "github.com/ory/kratos/identity" "github.com/ory/kratos/schema" "github.com/ory/kratos/selfservice/flow" @@ -19,10 +24,6 @@ import ( "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x" - "github.com/ory/x/decoderx" - "github.com/ory/x/sqlcon" - "github.com/ory/x/sqlxx" - "github.com/ory/x/urlx" ) const ( @@ -271,6 +272,7 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, if err != nil { return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err) } + sf.RequestURL = f.RequestURL if err := s.d.RecoveryExecutor().PostRecoveryHook(w, r, f, sess); err != nil { return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err) diff --git a/selfservice/strategy/link/strategy_recovery_test.go b/selfservice/strategy/link/strategy_recovery_test.go index 37eebaac8c12..c3b1935e2c1c 100644 --- a/selfservice/strategy/link/strategy_recovery_test.go +++ b/selfservice/strategy/link/strategy_recovery_test.go @@ -382,7 +382,7 @@ func TestRecovery(t *testing.T) { }) t.Run("description=should recover an account", func(t *testing.T) { - var check = func(t *testing.T, recoverySubmissionResponse, recoveryEmail string) { + var check = func(t *testing.T, recoverySubmissionResponse, recoveryEmail, returnTo string) { addr, err := reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, recoveryEmail) assert.NoError(t, err) assert.False(t, addr.Verified) @@ -413,6 +413,7 @@ func TestRecovery(t *testing.T) { body := ioutilx.MustReadAll(res.Body) assert.Equal(t, text.NewRecoverySuccessful(time.Now().Add(time.Hour)).Text, gjson.GetBytes(body, "ui.messages.0.text").String()) + assert.Equal(t, returnTo, gjson.GetBytes(body, "return_to").String()) addr, err = reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, recoveryEmail) assert.NoError(t, err) @@ -433,23 +434,46 @@ func TestRecovery(t *testing.T) { createIdentityToRecover(email) check(t, expectSuccess(t, nil, false, false, func(v url.Values) { v.Set("email", email) - }), email) + }), email, "") }) - t.Run("type=spa", func(t *testing.T) { + t.Run("type=browser set return_to", func(t *testing.T) { email := "recoverme2@ory.sh" + returnTo := "https://www.ory.sh" + createIdentityToRecover(email) + + hc := testhelpers.NewClientWithCookies(t) + hc.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper + + f := testhelpers.InitializeRecoveryFlowViaBrowser(t, hc, false, public, url.Values{"return_to": []string{returnTo}}) + + time.Sleep(time.Millisecond) // add a bit of delay to allow `1ns` to time out. + + formPayload := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes) + formPayload.Set("email", email) + + b, res := testhelpers.RecoveryMakeRequest(t, false, f, hc, testhelpers.EncodeFormAsJSON(t, false, formPayload)) + assert.EqualValues(t, http.StatusOK, res.StatusCode, "%s", b) + expectedURL := testhelpers.ExpectURL(false, public.URL+recovery.RouteSubmitFlow, conf.SelfServiceFlowRecoveryUI().String()) + assert.Contains(t, res.Request.URL.String(), expectedURL, "%+v\n\t%s", res.Request, b) + + check(t, b, email, returnTo) + }) + + t.Run("type=spa", func(t *testing.T) { + email := "recoverme3@ory.sh" createIdentityToRecover(email) check(t, expectSuccess(t, nil, true, true, func(v url.Values) { v.Set("email", email) - }), email) + }), email, "") }) t.Run("type=api", func(t *testing.T) { - email := "recoverme3@ory.sh" + email := "recoverme4@ory.sh" createIdentityToRecover(email) check(t, expectSuccess(t, nil, true, false, func(v url.Values) { v.Set("email", email) - }), email) + }), email, "") }) }) @@ -490,7 +514,7 @@ func TestRecovery(t *testing.T) { t.Run("description=should not be able to use an invalid link", func(t *testing.T) { c := testhelpers.NewClientWithCookies(t) - f := testhelpers.InitializeRecoveryFlowViaBrowser(t, c, false, public) + f := testhelpers.InitializeRecoveryFlowViaBrowser(t, c, false, public, nil) res, err := c.Get(f.Ui.Action + "&token=i-do-not-exist") require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) @@ -504,7 +528,7 @@ func TestRecovery(t *testing.T) { }) t.Run("description=should not be able to use an outdated link", func(t *testing.T) { - recoveryEmail := "recoverme4@ory.sh" + recoveryEmail := "recoverme5@ory.sh" createIdentityToRecover(recoveryEmail) conf.MustSet(config.ViperKeySelfServiceRecoveryRequestLifespan, time.Millisecond*200) t.Cleanup(func() { @@ -530,7 +554,7 @@ func TestRecovery(t *testing.T) { }) t.Run("description=should not be able to use an outdated flow", func(t *testing.T) { - recoveryEmail := "recoverme5@ory.sh" + recoveryEmail := "recoverme6@ory.sh" createIdentityToRecover(recoveryEmail) conf.MustSet(config.ViperKeySelfServiceRecoveryRequestLifespan, time.Millisecond*200) t.Cleanup(func() { @@ -601,7 +625,7 @@ func TestDisabledEndpoint(t *testing.T) { c := testhelpers.NewClientWithCookies(t) t.Run("description=can not recover an account by get request when link method is disabled", func(t *testing.T) { - f := testhelpers.InitializeRecoveryFlowViaBrowser(t, c, false, publicTS) + f := testhelpers.InitializeRecoveryFlowViaBrowser(t, c, false, publicTS, nil) u := publicTS.URL + recovery.RouteSubmitFlow + "?flow=" + f.Id + "&token=endpoint-disabled" res, err := c.Get(u) require.NoError(t, err) @@ -612,7 +636,7 @@ func TestDisabledEndpoint(t *testing.T) { }) t.Run("description=can not recover an account by post request when link method is disabled", func(t *testing.T) { - f := testhelpers.InitializeRecoveryFlowViaBrowser(t, c, false, publicTS) + f := testhelpers.InitializeRecoveryFlowViaBrowser(t, c, false, publicTS, nil) u := publicTS.URL + recovery.RouteSubmitFlow + "?flow=" + f.Id res, err := c.PostForm(u, url.Values{"email": {"email@ory.sh"}, "method": {"link"}}) require.NoError(t, err) diff --git a/selfservice/strategy/link/strategy_test.go b/selfservice/strategy/link/strategy_test.go index 95f1dd5a4e50..37286e32e762 100644 --- a/selfservice/strategy/link/strategy_test.go +++ b/selfservice/strategy/link/strategy_test.go @@ -11,6 +11,7 @@ import ( func initViper(t *testing.T, c *config.Config) { c.MustSet(config.ViperKeyDefaultIdentitySchemaURL, "file://./stub/default.schema.json") c.MustSet(config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh") + c.MustSet(config.ViperKeyURLsWhitelistedReturnToDomains, []string{"https://www.ory.sh"}) c.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+identity.CredentialsTypePassword.String()+".enabled", true) c.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+recovery.StrategyRecoveryLinkName+".enabled", true) c.MustSet(config.ViperKeySelfServiceRecoveryEnabled, true)