Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Traefik decision api support #904

Merged
merged 19 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions api/decision.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,19 @@ import (

const (
DecisionPath = "/decisions"

xForwardedMethod = "X-Forwarded-Method"
xForwardedProto = "X-Forwarded-Proto"
xForwardedHost = "X-Forwarded-Host"
xForwardedUri = "X-Forwarded-Uri"
)

type decisionHandlerRegistry interface {
x.RegistryWriter
x.RegistryLogger

RuleMatcher() rule.Matcher
ProxyRequestHandler() *proxy.RequestHandler
ProxyRequestHandler() proxy.RequestHandler
}

type DecisionHandler struct {
Expand All @@ -53,12 +58,11 @@ func NewJudgeHandler(r decisionHandlerRegistry) *DecisionHandler {

func (h *DecisionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if len(r.URL.Path) >= len(DecisionPath) && r.URL.Path[:len(DecisionPath)] == DecisionPath {
r.URL.Scheme = "http"
r.URL.Host = r.Host
if r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") {
r.URL.Scheme = "https"
}
r.URL.Path = r.URL.Path[len(DecisionPath):]
r.Method = x.OrDefaultString(r.Header.Get(xForwardedMethod), r.Method)
r.URL.Scheme = x.OrDefaultString(r.Header.Get(xForwardedProto),
x.IfThenElseString(r.TLS != nil, "https", "http"))
r.URL.Host = x.OrDefaultString(r.Header.Get(xForwardedHost), r.Host)
r.URL.Path = x.OrDefaultString(r.Header.Get(xForwardedUri), r.URL.Path[len(DecisionPath):])

h.decisions(w, r)
} else {
Expand Down Expand Up @@ -112,7 +116,6 @@ func (h *DecisionHandler) decisions(w http.ResponseWriter, r *http.Request) {
WithFields(fields).
WithField("granted", false).
Info("Access request denied")

h.r.ProxyRequestHandler().HandleError(w, r, rl, err)
return
}
Expand Down
144 changes: 137 additions & 7 deletions api/decision_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,30 @@ package api_test
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"

"github.com/ory/viper"

"github.com/urfave/negroni"

"github.com/ory/oathkeeper/driver/configuration"
"github.com/ory/oathkeeper/internal"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"

"github.com/ory/herodot"
"github.com/ory/oathkeeper/api"
"github.com/ory/oathkeeper/driver/configuration"
"github.com/ory/oathkeeper/internal"
"github.com/ory/oathkeeper/pipeline/authn"
"github.com/ory/oathkeeper/proxy"
"github.com/ory/oathkeeper/rule"
"github.com/ory/viper"
"github.com/ory/x/logrusx"
)

func TestDecisionAPI(t *testing.T) {
Expand Down Expand Up @@ -344,3 +349,128 @@ func TestDecisionAPI(t *testing.T) {
})
}
}

type decisionHandlerRegistryMock struct {
mock.Mock
}

func (m *decisionHandlerRegistryMock) RuleMatcher() rule.Matcher {
return m
}

func (m *decisionHandlerRegistryMock) ProxyRequestHandler() proxy.RequestHandler {
return m
}

func (*decisionHandlerRegistryMock) Writer() herodot.Writer {
return nil
}

func (*decisionHandlerRegistryMock) Logger() *logrusx.Logger {
return logrusx.New("", "")
}

func (m *decisionHandlerRegistryMock) Match(ctx context.Context, method string, u *url.URL) (*rule.Rule, error) {
args := m.Called(ctx, method, u)
return args.Get(0).(*rule.Rule), args.Error(1)
}

func (*decisionHandlerRegistryMock) HandleError(w http.ResponseWriter, r *http.Request, rl *rule.Rule, handleErr error) {
}

func (*decisionHandlerRegistryMock) HandleRequest(r *http.Request, rl *rule.Rule) (session *authn.AuthenticationSession, err error) {
return &authn.AuthenticationSession{}, nil
}

func (*decisionHandlerRegistryMock) InitializeAuthnSession(r *http.Request, rl *rule.Rule) *authn.AuthenticationSession {
return nil
}

func TestDecisionAPIHeaderUsage(t *testing.T) {
r := new(decisionHandlerRegistryMock)
h := api.NewJudgeHandler(r)
defaultUrl := &url.URL{Scheme: "http", Host: "ory.sh", Path: "/foo"}
defaultMethod := "GET"
defaultTransform := func(req *http.Request) {}

for _, tc := range []struct {
name string
expectedMethod string
expectedUrl *url.URL
transform func(req *http.Request)
}{
{
name: "all arguments are taken from the url and request method",
expectedUrl: defaultUrl,
expectedMethod: defaultMethod,
transform: defaultTransform,
},
{
name: "all arguments are taken from the url and request method, but scheme from URL TLS settings",
expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path},
expectedMethod: defaultMethod,
transform: func(req *http.Request) {
req.TLS = &tls.ConnectionState{}
},
},
{
name: "all arguments are taken from the headers",
expectedUrl: &url.URL{Scheme: "https", Host: "test.dev", Path: "/bar"},
expectedMethod: "POST",
transform: func(req *http.Request) {
req.Header.Add("X-Forwarded-Method", "POST")
req.Header.Add("X-Forwarded-Proto", "https")
req.Header.Add("X-Forwarded-Host", "test.dev")
req.Header.Add("X-Forwarded-Uri", "/bar")
},
},
{
name: "only scheme is taken from the headers",
expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path},
expectedMethod: defaultMethod,
transform: func(req *http.Request) {
req.Header.Add("X-Forwarded-Proto", "https")
},
},
{
name: "only method is taken from the headers",
expectedUrl: defaultUrl,
expectedMethod: "POST",
transform: func(req *http.Request) {
req.Header.Add("X-Forwarded-Method", "POST")
},
},
{
name: "only host is taken from the headers",
expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: "test.dev", Path: defaultUrl.Path},
expectedMethod: defaultMethod,
transform: func(req *http.Request) {
req.Header.Add("X-Forwarded-Host", "test.dev")
},
},
{
name: "only path is taken from the headers",
expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: defaultUrl.Host, Path: "/bar"},
expectedMethod: defaultMethod,
transform: func(req *http.Request) {
req.Header.Add("X-Forwarded-Uri", "/bar")
},
},
} {
t.Run(tc.name, func(t *testing.T) {
res := httptest.NewRecorder()
reqUrl := *defaultUrl
reqUrl.Path = api.DecisionPath + reqUrl.Path
req := httptest.NewRequest(defaultMethod, reqUrl.String(), nil)
tc.transform(req)

r.On("Match", mock.Anything,
mock.MatchedBy(func(val string) bool { return val == tc.expectedMethod }),
mock.MatchedBy(func(val *url.URL) bool { return *val == *tc.expectedUrl })).
Return(&rule.Rule{}, nil)
h.ServeHTTP(res, req, nil)

r.AssertExpectations(t)
})
}
}
14 changes: 8 additions & 6 deletions credentials/verifier_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/ecdsa"
"crypto/rsa"
"fmt"
"strings"

"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
Expand Down Expand Up @@ -42,7 +43,7 @@ func (v *VerifierDefault) Verify(

kid, ok := token.Header["kid"].(string)
if !ok || kid == "" {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The JSON Web Token must contain a kid header value but did not."))
return nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The JSON Web Token must contain a kid header value but did not."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing all of these is kind of a breaking change (although not an terrible one). Given that our service configuration returns something incorrect here, why would we report to the client calling this endpoint that it's a bad request?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to my understanding of the code, this is actually not about our service configuration, it is about the data sent by the client. If the client sent a malformed request (here the an access token not containing the expected data), it should imho be answered accordingly.

Another reason was logging. Internal server error is about unrecoverable errors, which imply some kind of a bug, which should be addressed. Here, we don't have things like this. So I wanted avoiding the be filled with error log statements, which are no errors.

Do you agree, or do you still want to have these changes reverted?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, in case that we verify credentials, Unauthorized error would also be acceptable. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main idea is not polluting the logs with error messages, which are actually not errors. I consider this as best practice.
Ah, if I remember correctly, to have the error handler working with traefik (to redirect to the login page), I had to change some of the responses to Unauthorized (which relate to authentication cases).

Back to your proposal. So you would rather like to see all changed responses to be Unauthorized? Fine for me ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please let me know if this is what you would like to have. I'll then update the PR accordingly.

}

key, err := v.r.CredentialsFetcher().ResolveKey(ctx, r.KeyURLs, kid, "sig")
Expand Down Expand Up @@ -74,10 +75,10 @@ func (v *VerifierDefault) Verify(
return k, nil
}
default:
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`This request object uses unsupported signing algorithm "%s".`, token.Header["alg"]))
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`This request object uses unsupported signing algorithm "%s".`, token.Header["alg"]))
}

return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`The signing key algorithm does not match the algorithm from the token header.`))
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`The signing key algorithm does not match the algorithm from the token header.`))
})
if err != nil {
if e, ok := errors.Cause(err).(*jwt.ValidationError); ok {
Expand All @@ -100,13 +101,14 @@ func (v *VerifierDefault) Verify(
parsedClaims := jwtx.ParseMapStringInterfaceClaims(claims)
for _, audience := range r.Audiences {
if !stringslice.Has(parsedClaims.Audience, audience) {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Token audience %v is not intended for target audience %s.", parsedClaims.Audience, audience))
return nil, herodot.ErrUnauthorized.WithReasonf("Token audience %v is not intended for target audience %s.", parsedClaims.Audience, audience)
}
}

if len(r.Issuers) > 0 {
if !stringslice.Has(r.Issuers, parsedClaims.Issuer) {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Token issuer does not match any trusted issuer."))
return nil, herodot.ErrUnauthorized.WithReasonf("Token issuer does not match any trusted issuer %s.", parsedClaims.Issuer).
WithDetail("received issuers", strings.Join(r.Issuers, ", "))
}
}

Expand All @@ -117,7 +119,7 @@ func (v *VerifierDefault) Verify(
if r.ScopeStrategy != nil {
for _, sc := range r.Scope {
if !r.ScopeStrategy(s, sc) {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`JSON Web Token is missing required scope "%s".`, sc))
return nil, herodot.ErrUnauthorized.WithReasonf(`JSON Web Token is missing required scope "%s".`, sc)
}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Registry interface {
BuildDate() string
BuildHash() string

ProxyRequestHandler() *proxy.RequestHandler
ProxyRequestHandler() proxy.RequestHandler
HealthEventManager() health.EventManager
HealthHandler() *healthx.Handler
RuleHandler() *api.RuleHandler
Expand Down
4 changes: 2 additions & 2 deletions driver/registry_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type RegistryMemory struct {
apiJudgeHandler *api.DecisionHandler
healthxHandler *healthx.Handler

proxyRequestHandler *proxy.RequestHandler
proxyRequestHandler proxy.RequestHandler
proxyProxy *proxy.Proxy
ruleFetcher rule.Fetcher

Expand Down Expand Up @@ -89,7 +89,7 @@ func (r *RegistryMemory) WithRuleFetcher(fetcher rule.Fetcher) Registry {
return r
}

func (r *RegistryMemory) ProxyRequestHandler() *proxy.RequestHandler {
func (r *RegistryMemory) ProxyRequestHandler() proxy.RequestHandler {
if r.proxyRequestHandler == nil {
r.proxyRequestHandler = proxy.NewRequestHandler(r, r.c)
}
Expand Down
17 changes: 14 additions & 3 deletions pipeline/errors/error_redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (

var _ Handler = new(ErrorRedirect)

const (
xForwardedProto = "X-Forwarded-Proto"
xForwardedHost = "X-Forwarded-Host"
xForwardedUri = "X-Forwarded-Uri"
)

type (
ErrorRedirectConfig struct {
To string `json:"to"`
Expand Down Expand Up @@ -40,7 +46,11 @@ func (a *ErrorRedirect) Handle(w http.ResponseWriter, r *http.Request, config js
return err
}

http.Redirect(w, r, a.RedirectURL(r, c), c.Code)
r.URL.Scheme = x.OrDefaultString(r.Header.Get(xForwardedProto), r.URL.Scheme)
r.URL.Host = x.OrDefaultString(r.Header.Get(xForwardedHost), r.URL.Host)
r.URL.Path = x.OrDefaultString(r.Header.Get(xForwardedUri), r.URL.Path)

http.Redirect(w, r, a.RedirectURL(r.URL, c), c.Code)
return nil
}

Expand Down Expand Up @@ -69,7 +79,7 @@ func (a *ErrorRedirect) GetID() string {
return "redirect"
}

func (a *ErrorRedirect) RedirectURL(r *http.Request, c *ErrorRedirectConfig) string {
func (a *ErrorRedirect) RedirectURL(uri *url.URL, c *ErrorRedirectConfig) string {
if c.ReturnToQueryParam == "" {
return c.To
}
Expand All @@ -78,8 +88,9 @@ func (a *ErrorRedirect) RedirectURL(r *http.Request, c *ErrorRedirectConfig) str
if err != nil {
return c.To
}

q := u.Query()
q.Set(c.ReturnToQueryParam, r.URL.String())
q.Set(c.ReturnToQueryParam, uri.String())
u.RawQuery = q.Encode()
return u.String()
}
Loading