-
Notifications
You must be signed in to change notification settings - Fork 359
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
authenticator: Add cookie session authenticator (#211)
- Loading branch information
Showing
6 changed files
with
314 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
package authn | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
|
||
"github.com/pkg/errors" | ||
|
||
"github.com/ory/oathkeeper/driver/configuration" | ||
"github.com/ory/oathkeeper/helper" | ||
"github.com/ory/oathkeeper/pipeline" | ||
) | ||
|
||
type AuthenticatorCookieSessionFilter struct { | ||
} | ||
|
||
type AuthenticatorCookieSessionConfiguration struct { | ||
Only []string `json:"only"` | ||
CheckSessionURL string `json:"check_session_url"` | ||
} | ||
|
||
type AuthenticatorCookieSession struct { | ||
c configuration.Provider | ||
} | ||
|
||
func NewAuthenticatorCookieSession(c configuration.Provider) *AuthenticatorCookieSession { | ||
return &AuthenticatorCookieSession{ | ||
c: c, | ||
} | ||
} | ||
|
||
func (a *AuthenticatorCookieSession) GetID() string { | ||
return "cookie_session" | ||
} | ||
|
||
func (a *AuthenticatorCookieSession) Validate() error { | ||
if !a.c.AuthenticatorCookieSessionIsEnabled() { | ||
return errors.WithStack(ErrAuthenticatorNotEnabled.WithReasonf(`Authenticator "%s" is disabled per configuration.`, a.GetID())) | ||
} | ||
|
||
if a.c.AuthenticatorCookieSessionCheckSessionURL() == "" { | ||
return errors.WithStack(ErrAuthenticatorNotEnabled.WithReasonf( | ||
`Configuration for authenticator "%s" did not specify any values for configuration key "%s" and is thus disabled.`, | ||
a.GetID(), | ||
configuration.ViperKeyAuthenticatorCookieSessionCheckSessionURL, | ||
)) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (a *AuthenticatorCookieSession) Authenticate(r *http.Request, config json.RawMessage, _ pipeline.Rule) (*AuthenticationSession, error) { | ||
var cf AuthenticatorCookieSessionConfiguration | ||
if len(config) == 0 { | ||
config = []byte("{}") | ||
} | ||
d := json.NewDecoder(bytes.NewBuffer(config)) | ||
d.DisallowUnknownFields() | ||
if err := d.Decode(&cf); err != nil { | ||
return nil, errors.WithStack(err) | ||
} | ||
|
||
only := cf.Only | ||
if len(only) == 0 { | ||
only = a.c.AuthenticatorCookieSessionOnly() | ||
} | ||
if !cookieSessionResponsible(r, only) { | ||
return nil, errors.WithStack(ErrAuthenticatorNotResponsible) | ||
} | ||
|
||
origin := cf.CheckSessionURL | ||
if origin == "" { | ||
origin = a.c.AuthenticatorCookieSessionCheckSessionURL() | ||
} | ||
|
||
body, err := forwardRequestToSessionStore(r, origin) | ||
if err != nil { | ||
return nil, helper.ErrForbidden.WithReason(err.Error()).WithTrace(err) | ||
} | ||
|
||
var session struct { | ||
Subject string `json:"subject"` | ||
Extra map[string]interface{} `json:"extra"` | ||
} | ||
err = json.Unmarshal(body, &session) | ||
if err != nil { | ||
return nil, helper.ErrForbidden.WithReason(err.Error()).WithTrace(err) | ||
} | ||
|
||
return &AuthenticationSession{ | ||
Subject: session.Subject, | ||
Extra: session.Extra, | ||
}, nil | ||
} | ||
|
||
func cookieSessionResponsible(r *http.Request, only []string) bool { | ||
if len(only) == 0 { | ||
return true | ||
} | ||
for _, cookieName := range only { | ||
if _, err := r.Cookie(cookieName); err == nil { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
func forwardRequestToSessionStore(r *http.Request, checkSessionURL string) (json.RawMessage, error) { | ||
reqUrl, err := url.Parse(checkSessionURL) | ||
if err != nil { | ||
return nil, helper.ErrForbidden.WithReason(err.Error()).WithTrace(err) | ||
} | ||
reqUrl.Path = r.URL.Path | ||
|
||
res, err := http.DefaultClient.Do(&http.Request{ | ||
Method: r.Method, | ||
URL: reqUrl, | ||
Header: r.Header, | ||
}) | ||
if err != nil { | ||
return nil, helper.ErrForbidden.WithReason(err.Error()).WithTrace(err) | ||
} | ||
|
||
if res.StatusCode == 200 { | ||
body, err := ioutil.ReadAll(res.Body) | ||
if err != nil { | ||
return json.RawMessage{}, err | ||
} | ||
return json.RawMessage(body), nil | ||
} else { | ||
return json.RawMessage{}, errors.WithStack(helper.ErrUnauthorized) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package authn_test | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"io/ioutil" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strconv" | ||
"testing" | ||
|
||
"github.com/ory/oathkeeper/internal" | ||
. "github.com/ory/oathkeeper/pipeline/authn" | ||
"github.com/pkg/errors" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestAuthenticatorCookieSession(t *testing.T) { | ||
conf := internal.NewConfigurationWithDefaults() | ||
reg := internal.NewRegistry(conf) | ||
|
||
pipelineAuthenticator, err := reg.PipelineAuthenticator("cookie_session") | ||
require.NoError(t, err) | ||
|
||
t.Run("method=authenticate", func(t *testing.T) { | ||
t.Run("description=should fail because session store returned 400", func(t *testing.T) { | ||
testServer, _ := makeServer(400, `{}`) | ||
_, err := pipelineAuthenticator.Authenticate( | ||
makeRequest("GET", "/", map[string]string{"sessionid": "zyx"}, ""), | ||
json.RawMessage(fmt.Sprintf(`{"check_session_url": "%s"}`, testServer.URL)), | ||
nil, | ||
) | ||
require.Error(t, err, "%#v", errors.Cause(err)) | ||
}) | ||
|
||
t.Run("description=should pass because session store returned 200", func(t *testing.T) { | ||
testServer, _ := makeServer(200, `{"subject": "123", "extra": {"foo": "bar"}}`) | ||
session, err := pipelineAuthenticator.Authenticate( | ||
makeRequest("GET", "/", map[string]string{"sessionid": "zyx"}, ""), | ||
json.RawMessage(fmt.Sprintf(`{"check_session_url": "%s"}`, testServer.URL)), | ||
nil, | ||
) | ||
require.NoError(t, err, "%#v", errors.Cause(err)) | ||
assert.Equal(t, &AuthenticationSession{ | ||
Subject: "123", | ||
Extra: map[string]interface{}{"foo": "bar"}, | ||
}, session) | ||
}) | ||
|
||
t.Run("description=should pass through method, path, and headers to auth server", func(t *testing.T) { | ||
testServer, requestRecorder := makeServer(200, `{"subject": "123"}`) | ||
session, err := pipelineAuthenticator.Authenticate( | ||
makeRequest("PUT", "/users/123?query=string", map[string]string{"sessionid": "zyx"}, ""), | ||
json.RawMessage(fmt.Sprintf(`{"check_session_url": "%s"}`, testServer.URL)), | ||
nil, | ||
) | ||
require.NoError(t, err, "%#v", errors.Cause(err)) | ||
assert.Len(t, requestRecorder.requests, 1) | ||
r := requestRecorder.requests[0] | ||
assert.Equal(t, r.Method, "PUT") | ||
assert.Equal(t, r.URL.Path, "/users/123?query=string") | ||
assert.Equal(t, r.Header.Get("Cookie"), "sessionid=zyx") | ||
assert.Equal(t, &AuthenticationSession{Subject: "123"}, session) | ||
}) | ||
|
||
t.Run("description=does not pass request body through to auth server", func(t *testing.T) { | ||
testServer, requestRecorder := makeServer(200, `{}`) | ||
pipelineAuthenticator.Authenticate( | ||
makeRequest("POST", "/", map[string]string{"sessionid": "zyx"}, "Some body..."), | ||
json.RawMessage(fmt.Sprintf(`{"check_session_url": "%s"}`, testServer.URL)), | ||
nil, | ||
) | ||
assert.Len(t, requestRecorder.requests, 1) | ||
assert.Len(t, requestRecorder.bodies, 1) | ||
r := requestRecorder.requests[0] | ||
assert.Equal(t, r.ContentLength, int64(0)) | ||
assert.Equal(t, requestRecorder.bodies[0], []byte{}) | ||
}) | ||
|
||
t.Run("description=should fallthrough if only is specified and no cookie specified is set", func(t *testing.T) { | ||
testServer, requestRecorder := makeServer(200, `{}`) | ||
_, err := pipelineAuthenticator.Authenticate( | ||
makeRequest("GET", "/", map[string]string{"sessionid": "zyx"}, ""), | ||
json.RawMessage(fmt.Sprintf(`{"only": ["session", "sid"], "check_session_url": "%s"}`, testServer.URL)), | ||
nil, | ||
) | ||
assert.Equal(t, errors.Cause(err), ErrAuthenticatorNotResponsible) | ||
assert.Empty(t, requestRecorder.requests) | ||
}) | ||
|
||
t.Run("description=should not fallthrough if only is specified and cookie specified is set", func(t *testing.T) { | ||
testServer, _ := makeServer(200, `{}`) | ||
_, err := pipelineAuthenticator.Authenticate( | ||
makeRequest("GET", "/", map[string]string{"sid": "zyx"}, ""), | ||
json.RawMessage(fmt.Sprintf(`{"only": ["session", "sid"], "check_session_url": "%s"}`, testServer.URL)), | ||
nil, | ||
) | ||
require.NoError(t, err, "%#v", errors.Cause(err)) | ||
}) | ||
|
||
}) | ||
} | ||
|
||
type RequestRecorder struct { | ||
requests []*http.Request | ||
bodies [][]byte | ||
} | ||
|
||
func makeServer(statusCode int, responseBody string) (*httptest.Server, *RequestRecorder) { | ||
requestRecorder := &RequestRecorder{} | ||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
requestRecorder.requests = append(requestRecorder.requests, r) | ||
requestBody, _ := ioutil.ReadAll(r.Body) | ||
requestRecorder.bodies = append(requestRecorder.bodies, requestBody) | ||
w.WriteHeader(statusCode) | ||
w.Write([]byte(responseBody)) | ||
})) | ||
return testServer, requestRecorder | ||
} | ||
|
||
func makeRequest(method string, path string, cookies map[string]string, bodyStr string) *http.Request { | ||
var body io.ReadCloser | ||
header := http.Header{} | ||
if bodyStr != "" { | ||
body = ioutil.NopCloser(bytes.NewBufferString(bodyStr)) | ||
header.Add("Content-Length", strconv.Itoa(len(bodyStr))) | ||
} | ||
req := &http.Request{ | ||
Method: method, | ||
URL: &url.URL{Path: path}, | ||
Header: header, | ||
Body: body, | ||
} | ||
for name, value := range cookies { | ||
req.AddCookie(&http.Cookie{Name: name, Value: value}) | ||
} | ||
return req | ||
} |