Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: accept recovery link from authenticated users (#1077) #2195

Merged
merged 2 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ replace (
// official SDK, allowing for the Ory CLI to consume Ory Kratos' CLI commands.
github.com/ory/kratos-client-go => ./internal/httpclient
go.mongodb.org/mongo-driver => go.mongodb.org/mongo-driver v1.4.6
golang.org/x/sys => golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac
golang.org/x/sys => golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8
gopkg.in/DataDog/dd-trace-go.v1 => gopkg.in/DataDog/dd-trace-go.v1 v1.27.1-0.20201005154917-54b73b3e126a
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2108,8 +2108,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 h1:OH54vjqzRWmbJ62fjuhxy7AxFFgoHN0/DPc/UrL8cAs=
golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20191110171634-ad39bd3f0407/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
Expand Down
26 changes: 22 additions & 4 deletions internal/testhelpers/handler_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ type mockDeps interface {
}

func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i))

MockSetSessionWithIdentity(t, reg, conf, i)(w, r, ps)
}
}

func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, conf *config.Config, i *identity.Identity) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
activeSession, _ := session.NewActiveSession(i, conf, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
if aal := r.URL.Query().Get("set_aal"); len(aal) > 0 {
activeSession.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel(aal)
Expand All @@ -56,10 +62,21 @@ func MockGetSession(t *testing.T, reg mockDeps) httprouter.Handle {
}

func MockMakeAuthenticatedRequest(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request) ([]byte, *http.Response) {
return MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, req, NewClientWithCookies(t))
}

func MockMakeAuthenticatedRequestWithClient(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request, client *http.Client) ([]byte, *http.Response) {
return MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, router, req, client, nil)
}

func MockMakeAuthenticatedRequestWithClientAndID(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request, client *http.Client, id *identity.Identity) ([]byte, *http.Response) {
set := "/" + uuid.New().String() + "/set"
router.GET(set, MockSetSession(t, reg, conf))
if id == nil {
router.GET(set, MockSetSession(t, reg, conf))
} else {
router.GET(set, MockSetSessionWithIdentity(t, reg, conf, id))
}

client := NewClientWithCookies(t)
MockHydrateCookieClient(t, client, "http://"+req.URL.Host+set+"?"+req.URL.Query().Encode())

res, err := client.Do(req)
Expand Down Expand Up @@ -94,6 +111,7 @@ func MockHydrateCookieClient(t *testing.T, c *http.Client, u string) {
res, err := c.Get(u)
require.NoError(t, err)
defer res.Body.Close()
body := x.MustReadAll(res.Body)
assert.EqualValues(t, http.StatusOK, res.StatusCode)

var found bool
Expand All @@ -102,7 +120,7 @@ func MockHydrateCookieClient(t *testing.T, c *http.Client, u string) {
found = true
}
}
require.True(t, found)
require.True(t, found, "got body: %s\ngot url: %s", body, res.Request.URL.String())
}

func MockSessionCreateHandlerWithIdentity(t *testing.T, reg mockDeps, i *identity.Identity) (httprouter.Handle, *session.Session) {
Expand Down
1 change: 1 addition & 0 deletions internal/testhelpers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func NewKratosServers(t *testing.T) (public, admin *httptest.Server) {
public = httptest.NewServer(x.NewRouterPublic())
admin = httptest.NewServer(x.NewRouterAdmin())

public.URL = strings.Replace(public.URL, "127.0.0.1", "localhost", -1)
t.Cleanup(public.Close)
t.Cleanup(admin.Close)
return
Expand Down
34 changes: 31 additions & 3 deletions internal/testhelpers/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package testhelpers
import (
"context"
"net/http"
"strings"
"testing"
"time"

"github.com/ory/nosurf"

"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"

Expand Down Expand Up @@ -52,14 +55,39 @@ func maybePersistSession(t *testing.T, reg *driver.RegistryDefault, sess *sessio
func NewHTTPClientWithSessionCookie(t *testing.T, reg *driver.RegistryDefault, sess *session.Session) *http.Client {
maybePersistSession(t, reg, sess)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, reg.SessionManager().IssueCookie(context.Background(), w, r, sess))
}))
})

if _, ok := reg.CSRFHandler().(*nosurf.CSRFHandler); ok {
handler = nosurf.New(handler)
}

ts := httptest.NewServer(handler)
defer ts.Close()

c := NewClientWithCookies(t)
MockHydrateCookieClient(t, c, ts.URL)
return c
}

func NewHTTPClientWithSessionCookieLocalhost(t *testing.T, reg *driver.RegistryDefault, sess *session.Session) *http.Client {
maybePersistSession(t, reg, sess)

var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, reg.SessionManager().IssueCookie(context.Background(), w, r, sess))
})

if _, ok := reg.CSRFHandler().(*nosurf.CSRFHandler); ok {
handler = nosurf.New(handler)
}

ts := httptest.NewServer(handler)
defer ts.Close()

c := NewClientWithCookies(t)

// This should work for other test servers as well because cookies ignore ports.
ts.URL = strings.Replace(ts.URL, "127.0.0.1", "localhost", 1)
MockHydrateCookieClient(t, c, ts.URL)
return c
}
Expand Down
56 changes: 56 additions & 0 deletions selfservice/flow/nosurf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package flow

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"

"github.com/ory/nosurf"
)

func TestGetCSRFToken(t *testing.T) {
noToken := &mockReg{
presentToken: "",
regeneratedToken: "regenerated",
}

tokenPresent := &mockReg{
presentToken: "existing",
regeneratedToken: "regenerated",
}

t.Run("case=no token, browser flow", func(t *testing.T) {
assert.Equal(t, "regenerated", GetCSRFToken(noToken, nil, nil, TypeBrowser))
})

t.Run("case=token present, browser flow", func(t *testing.T) {
assert.Equal(t, "existing", GetCSRFToken(tokenPresent, nil, nil, TypeBrowser))
})

t.Run("case=no token, api flow", func(t *testing.T) {
assert.Equal(t, "", GetCSRFToken(noToken, nil, nil, TypeAPI))
})

t.Run("case=token present, api flow", func(t *testing.T) {
assert.Equal(t, "existing", GetCSRFToken(tokenPresent, nil, nil, TypeAPI))
})
}

type mockReg struct {
presentToken, regeneratedToken string

nosurf.Handler
}

func (m *mockReg) GenerateCSRFToken(*http.Request) string {
return m.presentToken
}

func (m *mockReg) CSRFHandler() nosurf.Handler {
return m
}

func (m *mockReg) RegenerateToken(http.ResponseWriter, *http.Request) string {
return m.regeneratedToken
}
4 changes: 2 additions & 2 deletions selfservice/flow/recovery/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {

public.GET(RouteGetFlow, h.fetch)

public.GET(RouteSubmitFlow, h.d.SessionHandler().IsNotAuthenticated(h.submitFlow, redirect))
public.POST(RouteSubmitFlow, h.d.SessionHandler().IsNotAuthenticated(h.submitFlow, redirect))
public.GET(RouteSubmitFlow, h.submitFlow)
public.POST(RouteSubmitFlow, h.submitFlow)
}

func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
Expand Down
11 changes: 10 additions & 1 deletion selfservice/strategy/link/strategy_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,15 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
return s.recoveryUseToken(w, r, body)
}

if _, err := s.d.SessionManager().FetchFromRequest(r.Context(), r); err == nil {
if x.IsJSONRequest(r) {
session.RespondWithJSONErrorOnAuthenticated(s.d.Writer(), recovery.ErrAlreadyLoggedIn)(w, r, nil)
} else {
session.RedirectOnAuthenticated(s.d)(w, r, nil)
}
return errors.WithStack(flow.ErrCompletedByStrategy)
}

if err := flow.MethodEnabledAndAllowed(r.Context(), s.RecoveryStrategyID(), body.Method, s.d); err != nil {
return s.HandleRecoveryError(w, r, nil, body, err)
}
Expand Down Expand Up @@ -254,7 +263,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, f *recovery.Flow, id *identity.Identity) error {
f.UI.Messages.Clear()
f.State = recovery.StatePassedChallenge
f.SetCSRFToken(flow.GetCSRFToken(s.d, w, r, f.Type))
f.SetCSRFToken(s.d.CSRFHandler().RegenerateToken(w, r))
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
f.RecoveredIdentityID = uuid.NullUUID{
UUID: id.ID,
Valid: true,
Expand Down
100 changes: 90 additions & 10 deletions selfservice/strategy/link/strategy_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func TestRecovery(t *testing.T) {
_ = testhelpers.NewSettingsUIFlowEchoServer(t, reg)
_ = testhelpers.NewErrorTestServer(t, reg)

public, _ := testhelpers.NewKratosServerWithCSRF(t, reg)
public, _, publicRouter, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)

var createIdentityToRecover = func(email string) *identity.Identity {
var id = &identity.Identity{
Expand Down Expand Up @@ -298,7 +298,59 @@ func TestRecovery(t *testing.T) {
check(t, expectValidationError(t, nil, true, false, values), email)
})
}
})

t.Run("description=should try to submit the form while authenticated", func(t *testing.T) {
run := func(t *testing.T, flow string) {
isAPI := flow == "api"
isSPA := flow == "spa"
hc := testhelpers.NewDebugClient(t)
if !isAPI {
hc = testhelpers.NewClientWithCookies(t)
hc.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper
}

var f *kratos.SelfServiceRecoveryFlow
if isAPI {
f = testhelpers.InitializeRecoveryFlowViaAPI(t, hc, public)
} else {
f = testhelpers.InitializeRecoveryFlowViaBrowser(t, hc, isSPA, public, nil)
}

v := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
v.Set("email", "[email protected]")
v.Set("method", "link")

authClient := testhelpers.NewHTTPClientWithArbitrarySessionToken(t, reg)
if isAPI {
s, err := session.NewActiveSession(
&identity.Identity{ID: x.NewUUID(), State: identity.StateActive},
testhelpers.NewSessionLifespanProvider(time.Hour),
time.Now(),
identity.CredentialsTypePassword,
identity.AuthenticatorAssuranceLevel1,
)
require.NoError(t, err)
authClient = testhelpers.NewHTTPClientWithSessionCookieLocalhost(t, reg, s)
}

body, res := testhelpers.RecoveryMakeRequest(t, isAPI || isSPA, f, authClient, testhelpers.EncodeFormAsJSON(t, isAPI || isSPA, v))

if isAPI || isSPA {
assert.EqualValues(t, http.StatusBadRequest, res.StatusCode, "%s", body)
assert.Contains(t, res.Request.URL.String(), recovery.RouteSubmitFlow, "%+v\n\t%s", res.Request, body)
assertx.EqualAsJSONExcept(t, recovery.ErrAlreadyLoggedIn, json.RawMessage(gjson.Get(body, "error").Raw), nil)
} else {
assert.EqualValues(t, http.StatusOK, res.StatusCode, "%s", body)
assert.Contains(t, res.Request.URL.String(), conf.SelfServiceBrowserDefaultReturnTo().String(), "%+v\n\t%s", res.Request, body)
}
}

for _, f := range []string{"browser", "spa", "api"} {
t.Run("type="+f, func(t *testing.T) {
run(t, f)
})
}
})

t.Run("description=should try to recover an email that does not exist", func(t *testing.T) {
Expand Down Expand Up @@ -480,38 +532,66 @@ func TestRecovery(t *testing.T) {
})

t.Run("description=should recover an account and set the csrf cookies", func(t *testing.T) {
recoveryEmail := "[email protected]"

var check = func(t *testing.T, actual string) {
var check = func(t *testing.T, actual, recoveryEmail string, cl *http.Client, do func(*http.Client, *http.Request) (*http.Response, error)) {
message := testhelpers.CourierExpectMessage(t, reg, recoveryEmail, "Recover access to your account")
recoveryLink := testhelpers.CourierExpectLinkInMessage(t, message, 1)

cl := testhelpers.NewClientWithCookies(t)
cl.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
res, err := cl.Get(recoveryLink)
res, err := do(cl, x.NewTestHTTPRequest(t, "GET", recoveryLink, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusSeeOther, res.StatusCode)
require.Len(t, cl.Jar.Cookies(urlx.ParseOrPanic(public.URL)), 2)
cookies := spew.Sdump(cl.Jar.Cookies(urlx.ParseOrPanic(public.URL)))
assert.Contains(t, cookies, x.CSRFTokenName)
assert.Contains(t, cookies, "ory_kratos_session")
returnTo, err := res.Location()
require.NoError(t, err)
assert.Contains(t, returnTo.String(), conf.SelfServiceFlowSettingsUI().String(), "we end up at the settings screen")

rl := urlx.ParseOrPanic(recoveryLink)
actualRes, err := cl.Get(public.URL + recovery.RouteGetFlow + "?id=" + rl.Query().Get("flow"))
require.NoError(t, err)
body := x.MustReadAll(actualRes.Body)
require.NoError(t, actualRes.Body.Close())
assert.Equal(t, http.StatusOK, actualRes.StatusCode, "%s", body)
assert.Equal(t, string(recovery.StatePassedChallenge), gjson.GetBytes(body, "state").String(), "%s", body)
}

var values = func(v url.Values) {
v.Set("email", recoveryEmail)
}
email := x.NewUUID().String() + "@ory.sh"
id := createIdentityToRecover(email)

check(t, expectSuccess(t, nil, false, false, values))
t.Run("case=unauthenticated", func(t *testing.T) {
var values = func(v url.Values) {
v.Set("email", email)
}
check(t, expectSuccess(t, nil, false, false, values), email, testhelpers.NewClientWithCookies(t), (*http.Client).Do)
})

t.Run("case=already logged into another account", func(t *testing.T) {
var values = func(v url.Values) {
v.Set("email", email)
}

check(t, expectSuccess(t, nil, false, false, values), email, testhelpers.NewClientWithCookies(t), func(cl *http.Client, req *http.Request) (*http.Response, error) {
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, publicRouter.Router, req, cl)
return res, nil
})
})

t.Run("case=already logged into the account", func(t *testing.T) {
var values = func(v url.Values) {
v.Set("email", email)
}

cl := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, reg, id)
check(t, expectSuccess(t, nil, false, false, values), email, cl, func(_ *http.Client, req *http.Request) (*http.Response, error) {
_, res := testhelpers.MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, publicRouter.Router, req, cl, id)
return res, nil
})
})
})

t.Run("description=should recover and invalidate all other sessions if hook is set", func(t *testing.T) {
Expand Down
Loading