Skip to content

Commit

Permalink
feat(saml): relaystate continuity fix + unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: sebferrer <[email protected]>

Co-authored-by: ThibaultHerard <[email protected]>
  • Loading branch information
sebferrer and ThibHrrd committed Feb 14, 2023
1 parent acf51ae commit 9cfed21
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 6 deletions.
241 changes: 241 additions & 0 deletions continuity/manager_relaystate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package continuity_test

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/ory/kratos/driver/config"

"github.com/ory/kratos/internal/testhelpers"

"github.com/ory/x/ioutilx"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/herodot"
"github.com/ory/x/logrusx"

"github.com/ory/kratos/continuity"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/x"
)

func TestManagerRelayState(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)

testhelpers.SetDefaultIdentitySchema(conf, "file://../test/stub/identity/empty.schema.json")
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://www.ory.sh")
i := identity.NewIdentity("")
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i))

var newServer = func(t *testing.T, p continuity.Manager, tc *persisterTestCase) *httptest.Server {
writer := herodot.NewJSONWriter(logrusx.New("", ""))
router := httprouter.New()
router.PUT("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := p.Pause(r.Context(), w, r, ps.ByName("name"), tc.ro...); err != nil {
writer.WriteError(w, r, err)
return
}
w.WriteHeader(http.StatusNoContent)
})

router.POST("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
relayState := r.URL.Query().Get("RelayState")

r.PostForm = make(url.Values)
r.PostForm.Set("RelayState", relayState)

c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.wo...)
if err != nil {
writer.WriteError(w, r, err)
return
}
writer.Write(w, r, c)
})

router.DELETE("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
relayState := r.URL.Query().Get("RelayState")

r.PostForm = make(url.Values)
r.PostForm.Set("RelayState", relayState)

err := p.Abort(r.Context(), w, r, ps.ByName("name"))
if err != nil {
writer.WriteError(w, r, err)
return
}
w.WriteHeader(http.StatusNoContent)
})

ts := httptest.NewServer(router)
t.Cleanup(func() {
ts.Close()
})
return ts
}

var newClient = func() *http.Client {
return &http.Client{Jar: x.EasyCookieJar(t, nil)}
}

p := reg.RelayStateContinuityManager()
cl := newClient()

t.Run("case=continue cookie persists with same http client", func(t *testing.T) {
ts := newServer(t, p, new(persisterTestCase))
name := x.NewUUID().String()
href := ts.URL + "/" + name

res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

req := x.NewTestHTTPRequest(t, "POST", href, nil)
require.Len(t, res.Cookies(), 1)

res, err = cl.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "name").String(), name)

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 1)
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})

t.Run("case=continue cookie reconstructed and delivered with valid relaystate", func(t *testing.T) {
ts := newServer(t, p, new(persisterTestCase))
name := x.NewUUID().String()
href := ts.URL + "/" + name

res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

var relayState string

for _, c := range res.Cookies() {
relayState = c.Value
}

req := x.NewTestHTTPRequest(t, "POST", href+"?RelayState="+url.QueryEscape(relayState), nil)
require.Len(t, res.Cookies(), 1)

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "name").String(), name)

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 1)
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})

t.Run("case=continue cookie not delivered with invalid relaystate", func(t *testing.T) {
ts := newServer(t, p, new(persisterTestCase))
name := x.NewUUID().String()
href := ts.URL + "/" + name

res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

var relayState string

for _, c := range res.Cookies() {
relayState = c.Value
relayState = strings.Replace(relayState, "a", "b", 1)
}
require.Len(t, res.Cookies(), 1)

req := x.NewTestHTTPRequest(t, "POST", href+"?RelayState="+url.QueryEscape(relayState), nil)

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)

body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate")
})

t.Run("case=continue cookie not delivered without relaystate", func(t *testing.T) {
ts := newServer(t, p, new(persisterTestCase))
name := x.NewUUID().String()
href := ts.URL + "/" + name

res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)
require.Len(t, res.Cookies(), 1)

req := x.NewTestHTTPRequest(t, "POST", href, nil)

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)

body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate")
})

t.Run("case=pause, abort, and continue session with failure", func(t *testing.T) {
ts := newServer(t, p, new(persisterTestCase))
name := x.NewUUID().String()
href := ts.URL + "/" + name

res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

req := x.NewTestHTTPRequest(t, "DELETE", href, nil)

res, err = cl.Do(req)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
require.Equal(t, http.StatusNoContent, res.StatusCode)

req = x.NewTestHTTPRequest(t, "POST", href, nil)

res, err = cl.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)

body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate")
})
}
8 changes: 2 additions & 6 deletions x/relaystate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ import (
func SessionGetStringRelayState(r *http.Request, s sessions.StoreExact, id string, key interface{}) (string, error) {

cipherRelayState := r.PostForm.Get("RelayState")
if cipherRelayState == "" {
return "", errors.New("The RelayState is empty or not exists")
}

// Reconstructs the cookie from the ciphered value
continuityCookie := &http.Cookie{
Expand All @@ -22,8 +19,7 @@ func SessionGetStringRelayState(r *http.Request, s sessions.StoreExact, id strin
MaxAge: 300,
}

r2 := r.Clone(r.Context())
r2.AddCookie(continuityCookie)
r.AddCookie(continuityCookie)

check := func(v map[interface{}]interface{}) (string, error) {
vv, ok := v[key]
Expand All @@ -37,7 +33,7 @@ func SessionGetStringRelayState(r *http.Request, s sessions.StoreExact, id strin
}

var exactErr error
sessionCookie, err := s.GetExact(r2, id, func(s *sessions.Session) bool {
sessionCookie, err := s.GetExact(r, id, func(s *sessions.Session) bool {
_, exactErr = check(s.Values)
return exactErr == nil
})
Expand Down

0 comments on commit 9cfed21

Please sign in to comment.