diff --git a/api/decision.go b/api/decision.go index 2325ec91fe..5b0fab66cd 100644 --- a/api/decision.go +++ b/api/decision.go @@ -33,6 +33,11 @@ import ( const ( DecisionPath = "/decisions" + + xForwardedMethod = "X-Forwarded-Method" + xForwardedProto = "X-Forwarded-Proto" + xForwardedHost = "X-Forwarded-Host" + xForwardedUri = "X-Forwarded-Uri" ) type decisionHandlerRegistry interface { @@ -40,7 +45,7 @@ type decisionHandlerRegistry interface { x.RegistryLogger RuleMatcher() rule.Matcher - ProxyRequestHandler() *proxy.RequestHandler + ProxyRequestHandler() proxy.RequestHandler } type DecisionHandler struct { @@ -53,12 +58,11 @@ func NewJudgeHandler(r decisionHandlerRegistry) *DecisionHandler { func (h *DecisionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { if len(r.URL.Path) >= len(DecisionPath) && r.URL.Path[:len(DecisionPath)] == DecisionPath { - r.URL.Scheme = "http" - r.URL.Host = r.Host - if r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") { - r.URL.Scheme = "https" - } - r.URL.Path = r.URL.Path[len(DecisionPath):] + r.Method = x.OrDefaultString(r.Header.Get(xForwardedMethod), r.Method) + r.URL.Scheme = x.OrDefaultString(r.Header.Get(xForwardedProto), + x.IfThenElseString(r.TLS != nil, "https", "http")) + r.URL.Host = x.OrDefaultString(r.Header.Get(xForwardedHost), r.Host) + r.URL.Path = x.OrDefaultString(r.Header.Get(xForwardedUri), r.URL.Path[len(DecisionPath):]) h.decisions(w, r) } else { @@ -112,7 +116,6 @@ func (h *DecisionHandler) decisions(w http.ResponseWriter, r *http.Request) { WithFields(fields). WithField("granted", false). Info("Access request denied") - h.r.ProxyRequestHandler().HandleError(w, r, rl, err) return } diff --git a/api/decision_test.go b/api/decision_test.go index 06c3b19609..dcb44a574f 100644 --- a/api/decision_test.go +++ b/api/decision_test.go @@ -23,25 +23,30 @@ package api_test import ( "bytes" "context" + "crypto/tls" "fmt" "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strconv" "testing" - "github.com/ory/viper" - - "github.com/urfave/negroni" - - "github.com/ory/oathkeeper/driver/configuration" - "github.com/ory/oathkeeper/internal" - "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/urfave/negroni" + "github.com/ory/herodot" + "github.com/ory/oathkeeper/api" + "github.com/ory/oathkeeper/driver/configuration" + "github.com/ory/oathkeeper/internal" + "github.com/ory/oathkeeper/pipeline/authn" + "github.com/ory/oathkeeper/proxy" "github.com/ory/oathkeeper/rule" + "github.com/ory/viper" + "github.com/ory/x/logrusx" ) func TestDecisionAPI(t *testing.T) { @@ -344,3 +349,128 @@ func TestDecisionAPI(t *testing.T) { }) } } + +type decisionHandlerRegistryMock struct { + mock.Mock +} + +func (m *decisionHandlerRegistryMock) RuleMatcher() rule.Matcher { + return m +} + +func (m *decisionHandlerRegistryMock) ProxyRequestHandler() proxy.RequestHandler { + return m +} + +func (*decisionHandlerRegistryMock) Writer() herodot.Writer { + return nil +} + +func (*decisionHandlerRegistryMock) Logger() *logrusx.Logger { + return logrusx.New("", "") +} + +func (m *decisionHandlerRegistryMock) Match(ctx context.Context, method string, u *url.URL) (*rule.Rule, error) { + args := m.Called(ctx, method, u) + return args.Get(0).(*rule.Rule), args.Error(1) +} + +func (*decisionHandlerRegistryMock) HandleError(w http.ResponseWriter, r *http.Request, rl *rule.Rule, handleErr error) { +} + +func (*decisionHandlerRegistryMock) HandleRequest(r *http.Request, rl *rule.Rule) (session *authn.AuthenticationSession, err error) { + return &authn.AuthenticationSession{}, nil +} + +func (*decisionHandlerRegistryMock) InitializeAuthnSession(r *http.Request, rl *rule.Rule) *authn.AuthenticationSession { + return nil +} + +func TestDecisionAPIHeaderUsage(t *testing.T) { + r := new(decisionHandlerRegistryMock) + h := api.NewJudgeHandler(r) + defaultUrl := &url.URL{Scheme: "http", Host: "ory.sh", Path: "/foo"} + defaultMethod := "GET" + defaultTransform := func(req *http.Request) {} + + for _, tc := range []struct { + name string + expectedMethod string + expectedUrl *url.URL + transform func(req *http.Request) + }{ + { + name: "all arguments are taken from the url and request method", + expectedUrl: defaultUrl, + expectedMethod: defaultMethod, + transform: defaultTransform, + }, + { + name: "all arguments are taken from the url and request method, but scheme from URL TLS settings", + expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path}, + expectedMethod: defaultMethod, + transform: func(req *http.Request) { + req.TLS = &tls.ConnectionState{} + }, + }, + { + name: "all arguments are taken from the headers", + expectedUrl: &url.URL{Scheme: "https", Host: "test.dev", Path: "/bar"}, + expectedMethod: "POST", + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Method", "POST") + req.Header.Add("X-Forwarded-Proto", "https") + req.Header.Add("X-Forwarded-Host", "test.dev") + req.Header.Add("X-Forwarded-Uri", "/bar") + }, + }, + { + name: "only scheme is taken from the headers", + expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path}, + expectedMethod: defaultMethod, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Proto", "https") + }, + }, + { + name: "only method is taken from the headers", + expectedUrl: defaultUrl, + expectedMethod: "POST", + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Method", "POST") + }, + }, + { + name: "only host is taken from the headers", + expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: "test.dev", Path: defaultUrl.Path}, + expectedMethod: defaultMethod, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Host", "test.dev") + }, + }, + { + name: "only path is taken from the headers", + expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: defaultUrl.Host, Path: "/bar"}, + expectedMethod: defaultMethod, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Uri", "/bar") + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + res := httptest.NewRecorder() + reqUrl := *defaultUrl + reqUrl.Path = api.DecisionPath + reqUrl.Path + req := httptest.NewRequest(defaultMethod, reqUrl.String(), nil) + tc.transform(req) + + r.On("Match", mock.Anything, + mock.MatchedBy(func(val string) bool { return val == tc.expectedMethod }), + mock.MatchedBy(func(val *url.URL) bool { return *val == *tc.expectedUrl })). + Return(&rule.Rule{}, nil) + h.ServeHTTP(res, req, nil) + + r.AssertExpectations(t) + }) + } +} diff --git a/credentials/verifier_default.go b/credentials/verifier_default.go index b78c6c0f2c..e1ed43d07d 100644 --- a/credentials/verifier_default.go +++ b/credentials/verifier_default.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "crypto/rsa" "fmt" + "strings" "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" @@ -42,7 +43,7 @@ func (v *VerifierDefault) Verify( kid, ok := token.Header["kid"].(string) if !ok || kid == "" { - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The JSON Web Token must contain a kid header value but did not.")) + return nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The JSON Web Token must contain a kid header value but did not.")) } key, err := v.r.CredentialsFetcher().ResolveKey(ctx, r.KeyURLs, kid, "sig") @@ -74,10 +75,10 @@ func (v *VerifierDefault) Verify( return k, nil } default: - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`This request object uses unsupported signing algorithm "%s".`, token.Header["alg"])) + return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`This request object uses unsupported signing algorithm "%s".`, token.Header["alg"])) } - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`The signing key algorithm does not match the algorithm from the token header.`)) + return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`The signing key algorithm does not match the algorithm from the token header.`)) }) if err != nil { if e, ok := errors.Cause(err).(*jwt.ValidationError); ok { @@ -100,13 +101,14 @@ func (v *VerifierDefault) Verify( parsedClaims := jwtx.ParseMapStringInterfaceClaims(claims) for _, audience := range r.Audiences { if !stringslice.Has(parsedClaims.Audience, audience) { - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Token audience %v is not intended for target audience %s.", parsedClaims.Audience, audience)) + return nil, herodot.ErrUnauthorized.WithReasonf("Token audience %v is not intended for target audience %s.", parsedClaims.Audience, audience) } } if len(r.Issuers) > 0 { if !stringslice.Has(r.Issuers, parsedClaims.Issuer) { - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Token issuer does not match any trusted issuer.")) + return nil, herodot.ErrUnauthorized.WithReasonf("Token issuer does not match any trusted issuer %s.", parsedClaims.Issuer). + WithDetail("received issuers", strings.Join(r.Issuers, ", ")) } } @@ -117,7 +119,7 @@ func (v *VerifierDefault) Verify( if r.ScopeStrategy != nil { for _, sc := range r.Scope { if !r.ScopeStrategy(s, sc) { - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`JSON Web Token is missing required scope "%s".`, sc)) + return nil, herodot.ErrUnauthorized.WithReasonf(`JSON Web Token is missing required scope "%s".`, sc) } } } else { diff --git a/driver/registry.go b/driver/registry.go index 14c62daabd..11a8ed7630 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -30,7 +30,7 @@ type Registry interface { BuildDate() string BuildHash() string - ProxyRequestHandler() *proxy.RequestHandler + ProxyRequestHandler() proxy.RequestHandler HealthEventManager() health.EventManager HealthHandler() *healthx.Handler RuleHandler() *api.RuleHandler diff --git a/driver/registry_memory.go b/driver/registry_memory.go index 69ace4247e..fda492c768 100644 --- a/driver/registry_memory.go +++ b/driver/registry_memory.go @@ -53,7 +53,7 @@ type RegistryMemory struct { apiJudgeHandler *api.DecisionHandler healthxHandler *healthx.Handler - proxyRequestHandler *proxy.RequestHandler + proxyRequestHandler proxy.RequestHandler proxyProxy *proxy.Proxy ruleFetcher rule.Fetcher @@ -89,7 +89,7 @@ func (r *RegistryMemory) WithRuleFetcher(fetcher rule.Fetcher) Registry { return r } -func (r *RegistryMemory) ProxyRequestHandler() *proxy.RequestHandler { +func (r *RegistryMemory) ProxyRequestHandler() proxy.RequestHandler { if r.proxyRequestHandler == nil { r.proxyRequestHandler = proxy.NewRequestHandler(r, r.c) } diff --git a/pipeline/errors/error_redirect.go b/pipeline/errors/error_redirect.go index 4d90498d0d..9235df1659 100644 --- a/pipeline/errors/error_redirect.go +++ b/pipeline/errors/error_redirect.go @@ -12,6 +12,12 @@ import ( var _ Handler = new(ErrorRedirect) +const ( + xForwardedProto = "X-Forwarded-Proto" + xForwardedHost = "X-Forwarded-Host" + xForwardedUri = "X-Forwarded-Uri" +) + type ( ErrorRedirectConfig struct { To string `json:"to"` @@ -40,7 +46,11 @@ func (a *ErrorRedirect) Handle(w http.ResponseWriter, r *http.Request, config js return err } - http.Redirect(w, r, a.RedirectURL(r, c), c.Code) + r.URL.Scheme = x.OrDefaultString(r.Header.Get(xForwardedProto), r.URL.Scheme) + r.URL.Host = x.OrDefaultString(r.Header.Get(xForwardedHost), r.URL.Host) + r.URL.Path = x.OrDefaultString(r.Header.Get(xForwardedUri), r.URL.Path) + + http.Redirect(w, r, a.RedirectURL(r.URL, c), c.Code) return nil } @@ -69,7 +79,7 @@ func (a *ErrorRedirect) GetID() string { return "redirect" } -func (a *ErrorRedirect) RedirectURL(r *http.Request, c *ErrorRedirectConfig) string { +func (a *ErrorRedirect) RedirectURL(uri *url.URL, c *ErrorRedirectConfig) string { if c.ReturnToQueryParam == "" { return c.To } @@ -78,8 +88,9 @@ func (a *ErrorRedirect) RedirectURL(r *http.Request, c *ErrorRedirectConfig) str if err != nil { return c.To } + q := u.Query() - q.Set(c.ReturnToQueryParam, r.URL.String()) + q.Set(c.ReturnToQueryParam, uri.String()) u.RawQuery = q.Encode() return u.String() } diff --git a/pipeline/errors/error_redirect_test.go b/pipeline/errors/error_redirect_test.go index f6e929095f..802ecd83d7 100644 --- a/pipeline/errors/error_redirect_test.go +++ b/pipeline/errors/error_redirect_test.go @@ -178,3 +178,81 @@ func TestErrorRedirect(t *testing.T) { } }) } + +func TestErrorReturnToRedirectURLHeaderUsage(t *testing.T) { + conf := internal.NewConfigurationWithDefaults() + reg := internal.NewRegistry(conf) + + defaultUrl := &url.URL{Scheme: "http", Host: "ory.sh", Path: "/foo"} + defaultTransform := func(req *http.Request) {} + config := `{"to":"http://test/test","return_to_query_param":"return_to"}` + + a, err := reg.PipelineErrorHandler("redirect") + require.NoError(t, err) + assert.Equal(t, "redirect", a.GetID()) + + for _, tc := range []struct { + name string + expectedUrl *url.URL + transform func(req *http.Request) + }{ + { + name: "all arguments are taken from the url and request method", + expectedUrl: defaultUrl, + transform: defaultTransform, + }, + { + name: "all arguments are taken from the headers", + expectedUrl: &url.URL{Scheme: "https", Host: "test.dev", Path: "/bar"}, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Proto", "https") + req.Header.Add("X-Forwarded-Host", "test.dev") + req.Header.Add("X-Forwarded-Uri", "/bar") + }, + }, + { + name: "only scheme is taken from the headers", + expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path}, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Proto", "https") + }, + }, + { + name: "only host is taken from the headers", + expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: "test.dev", Path: defaultUrl.Path}, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Host", "test.dev") + }, + }, + { + name: "only path is taken from the headers", + expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: defaultUrl.Host, Path: "/bar"}, + transform: func(req *http.Request) { + req.Header.Add("X-Forwarded-Uri", "/bar") + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", defaultUrl.String(), nil) + tc.transform(r) + + err = a.Handle(w, r, json.RawMessage(config), nil, nil) + assert.NoError(t, err) + + loc := w.Header().Get("Location") + assert.NotEmpty(t, loc) + + locUrl, err := url.Parse(loc) + assert.NoError(t, err) + + returnTo := locUrl.Query().Get("return_to") + assert.NotEmpty(t, returnTo) + + returnToUrl, err := url.Parse(returnTo) + assert.NoError(t, err) + + assert.Equal(t, tc.expectedUrl, returnToUrl) + }) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index cca60496fb..2d9c65bdf3 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -39,7 +39,7 @@ type proxyRegistry interface { x.RegistryLogger x.RegistryWriter - ProxyRequestHandler() *RequestHandler + ProxyRequestHandler() RequestHandler RuleMatcher() rule.Matcher } diff --git a/proxy/request_handler.go b/proxy/request_handler.go index 89ff378b86..b9fbd3b20b 100644 --- a/proxy/request_handler.go +++ b/proxy/request_handler.go @@ -51,7 +51,13 @@ type requestHandlerRegistry interface { pe.Registry } -type RequestHandler struct { +type RequestHandler interface { + HandleError(w http.ResponseWriter, r *http.Request, rl *rule.Rule, handleErr error) + HandleRequest(r *http.Request, rl *rule.Rule) (session *authn.AuthenticationSession, err error) + InitializeAuthnSession(r *http.Request, rl *rule.Rule) *authn.AuthenticationSession +} + +type requestHandler struct { r requestHandlerRegistry c configuration.Provider } @@ -60,12 +66,12 @@ type whenConfig struct { When pe.Whens `json:"when"` } -func NewRequestHandler(r requestHandlerRegistry, c configuration.Provider) *RequestHandler { - return &RequestHandler{r: r, c: c} +func NewRequestHandler(r requestHandlerRegistry, c configuration.Provider) RequestHandler { + return &requestHandler{r: r, c: c} } // matchesWhen -func (d *RequestHandler) matchesWhen(w http.ResponseWriter, r *http.Request, h pe.Handler, config json.RawMessage, handleErr error) error { +func (d *requestHandler) matchesWhen(w http.ResponseWriter, r *http.Request, h pe.Handler, config json.RawMessage, handleErr error) error { var when whenConfig if err := d.c.ErrorHandlerConfig(h.GetID(), config, &when); err != nil { d.r.Writer().WriteError(w, r, pe.NewErrErrorHandlerMisconfigured(h, err)) @@ -83,7 +89,7 @@ func (d *RequestHandler) matchesWhen(w http.ResponseWriter, r *http.Request, h p return nil } -func (d *RequestHandler) HandleError(w http.ResponseWriter, r *http.Request, rl *rule.Rule, handleErr error) { +func (d *requestHandler) HandleError(w http.ResponseWriter, r *http.Request, rl *rule.Rule, handleErr error) { if rl == nil { // Create a new, empty rule. rl = new(rule.Rule) @@ -167,7 +173,7 @@ func (d *RequestHandler) HandleError(w http.ResponseWriter, r *http.Request, rl } } -func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) (session *authn.AuthenticationSession, err error) { +func (d *requestHandler) HandleRequest(r *http.Request, rl *rule.Rule) (session *authn.AuthenticationSession, err error) { var found bool fields := map[string]interface{}{ @@ -333,7 +339,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) (session } // InitializeAuthnSession creates an authentication session and initializes it with a Match context if possible -func (d *RequestHandler) InitializeAuthnSession(r *http.Request, rl *rule.Rule) *authn.AuthenticationSession { +func (d *requestHandler) InitializeAuthnSession(r *http.Request, rl *rule.Rule) *authn.AuthenticationSession { session := &authn.AuthenticationSession{ Subject: "", diff --git a/x/compare.go b/x/compare.go new file mode 100644 index 0000000000..97c29d3422 --- /dev/null +++ b/x/compare.go @@ -0,0 +1,15 @@ +package x + +func OrDefaultString(val, defaultVal string) string { + if val == "" { + return defaultVal + } + return val +} + +func IfThenElseString(c bool, thenVal, elseVal string) string { + if c { + return thenVal + } + return elseVal +}