diff --git a/.schema/config.schema.json b/.schema/config.schema.json index fa4338a4fc..af6141cf0c 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -1212,6 +1212,28 @@ "timeout": { "$ref": "#/definitions/serverTimeout" }, + "upstream": { + "type": "object", + "title": "HTTP Upstream", + "additionalProperties": false, + "properties": { + "ca_append_crt_path": { + "type": "string", + "default": "", + "examples": [ + "./self-signed.crt" + ], + "title": "CA Certificate", + "description": "The file containing the CA certificates to append to the Root CA when using upstream connections." + }, + "ca_refresh_frequency": { + "type": "integer", + "default": 1000, + "title": "CA Refresh Frequency", + "description": "Test if upstream transport needs certificate refresh once every 1000 requests." + } + } + }, "cors": { "$ref": "#/definitions/cors" }, diff --git a/docs/docs/api-access-rules.md b/docs/docs/api-access-rules.md index 2f4695ff76..8a988fd1a4 100644 --- a/docs/docs/api-access-rules.md +++ b/docs/docs/api-access-rules.md @@ -234,6 +234,21 @@ authenticators: } ``` +## Upstream Configuration + +Upstreams are often protected by SSL certificates trusted by the SSL root servers, but what if you want to run your upstream infrastructure using self-signed certificates? Use the `ca_append_crt_path` to specifiy the certificate file to append to the Root Certificate Authority (CA) of the upstream transport. Upstream transport is cached to reduce IO overhead at scale. The upstream transport certificate will automatically update if the certificate file path is changed. + +Use the `ca_refresh_frequency`to test once every `1000` requests for a certificate file size or modification timestamp change. Setting the value to 0 will disable the refresh on a frequency functionality, but retrain the update on file path change functionality. + +**oathkeeper.yml** +```yaml +serve: + proxy: + upstream: + ca_append_crt_path: "/my-certs/certs-to-append.crt", + ca_refresh_frequency: 1000 +``` + ## Scoped Credentials Some credentials are scoped. For example, OAuth 2.0 Access Tokens usually are diff --git a/docs/docs/reference/configuration.md b/docs/docs/reference/configuration.md index 393c8da233..e0df655e12 100644 --- a/docs/docs/reference/configuration.md +++ b/docs/docs/reference/configuration.md @@ -1538,6 +1538,46 @@ serve: # read: 5s + ## HTTP Upstream ## + # + # Control the HTTP upstream. + # + upstream: + ## Append Certificate To Root CA ## + # + # The path to a certificate file to append to the Root Certificate Authority for the upstream connection. Use this to accept self-signed certificates on the upstream only, keeping the host system certificate authority unaltered. + # + # Default value: "" + # + # Examples: + # - self-signed.crt + # + # Set this value using environment variables on + # - Linux/macOS: + # $ export SERVE_PROXY_UPSTREAM_CA_APPEND_CRT_PATH= + # - Windows Command Line (CMD): + # > set SERVE_PROXY_UPSTREAM_CA_APPEND_CRT_PATH= + # + ca_append_crt_path: "" + + ## CA Refresh Frequency ## + # + # The interval at which to test if the upstream transport certificate changed. Use 0 to disable test. + # + # Default value: 1000 + # + # Examples: + # - 1000 + # - 0 + # + # Set this value using environment variables on + # - Linux/macOS: + # $ export SERVE_PROXY_UPSTREAM_CA_REFRESH_FREQUENCY= + # - Windows Command Line (CMD): + # > set SERVE_PROXY_UPSTREAM_CA_REFRESH_FREQUENCY= + # + ca_refresh_frequency: 1000 + ## Cross Origin Resource Sharing (CORS) ## # # Configure [Cross Origin Resource Sharing (CORS)](http://www.w3.org/TR/cors/) using the following options. diff --git a/driver/configuration/provider.go b/driver/configuration/provider.go index 7812f3b6c5..e29a7fb490 100644 --- a/driver/configuration/provider.go +++ b/driver/configuration/provider.go @@ -5,7 +5,7 @@ import ( "net/url" "time" - "github.com/gobuffalo/packr/v2" + packr "github.com/gobuffalo/packr/v2" "github.com/ory/fosite" "github.com/ory/x/tracing" @@ -40,6 +40,8 @@ type Provider interface { ProxyReadTimeout() time.Duration ProxyWriteTimeout() time.Duration ProxyIdleTimeout() time.Duration + ProxyServeUpstreamCaAppendCrtPath() string + ProxyServeUpstreamCaRefreshFrequency() int APIReadTimeout() time.Duration APIWriteTimeout() time.Duration diff --git a/driver/configuration/provider_viper.go b/driver/configuration/provider_viper.go index 4a68415825..dfce6c25ea 100644 --- a/driver/configuration/provider_viper.go +++ b/driver/configuration/provider_viper.go @@ -43,6 +43,8 @@ const ( ViperKeyProxyIdleTimeout = "serve.proxy.timeout.idle" ViperKeyProxyServeAddressHost = "serve.proxy.host" ViperKeyProxyServeAddressPort = "serve.proxy.port" + ViperKeyProxyUpstreamCaAppendCrtPath = "serve.proxy.upstream.ca_append_crt_path" + ViperKeyProxyUpstreamCaRefreshFrequency = "serve.proxy.upstream.ca_refresh_frequency" ViperKeyAPIServeAddressHost = "serve.api.host" ViperKeyAPIServeAddressPort = "serve.api.port" ViperKeyAPIReadTimeout = "serve.api.timeout.read" @@ -179,6 +181,14 @@ func (v *ViperProvider) ProxyServeAddress() string { ) } +func (v *ViperProvider) ProxyServeUpstreamCaAppendCrtPath() string { + return viperx.GetString(v.l, ViperKeyProxyUpstreamCaAppendCrtPath, "") +} + +func (v *ViperProvider) ProxyServeUpstreamCaRefreshFrequency() int { + return viperx.GetInt(v.l, ViperKeyProxyUpstreamCaRefreshFrequency, 1000) +} + func (v *ViperProvider) APIReadTimeout() time.Duration { return viperx.GetDuration(v.l, ViperKeyAPIReadTimeout, time.Second*5) } diff --git a/driver/registry_memory.go b/driver/registry_memory.go index 62b1fe7e8d..0ce8603601 100644 --- a/driver/registry_memory.go +++ b/driver/registry_memory.go @@ -2,6 +2,11 @@ package driver import ( "context" + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net/http" + "os" "sync" "github.com/ory/oathkeeper/driver/health" @@ -57,6 +62,11 @@ type RegistryMemory struct { proxyProxy *proxy.Proxy ruleFetcher rule.Fetcher + cachedCertFilePath string + cachedCertFileStat *os.FileInfo + upstreamRequestCount int64 + upstreamTransport *http.Transport + authenticators map[string]authn.Authenticator authorizers map[string]authz.Authorizer mutators map[string]mutate.Mutator @@ -101,6 +111,108 @@ func (r *RegistryMemory) RuleMatcher() rule.Matcher { return r.ruleRepository } +func (r *RegistryMemory) isCachedUpstreamTransportDirty(certFile string) (bool, error) { + + // No cached transport yet, but configuration requires it, so handle as if it was dirty so we get one. + if r.upstreamTransport == nil { + return true, nil + } + + // The transport is dirty if the cert file path changed. + if certFile != r.cachedCertFilePath { + return true, nil + } + + // Transport is dirty if the certificate bytes changed or the modification time, + // but only check this based on the ca_refresh_frequency option to reduce IO impact. + caRefreshFrequency := r.c.ProxyServeUpstreamCaRefreshFrequency() + if caRefreshFrequency <= 0 { + return false, nil // Periodic refresh of the Ca is diabled. + } + + if r.upstreamRequestCount < int64(caRefreshFrequency) { + return false, nil + } + r.upstreamRequestCount = 0 + + stat, err := os.Stat(certFile) + if err != nil { + return false, err + } + + if stat.Size() != (*r.cachedCertFileStat).Size() || stat.ModTime() != (*r.cachedCertFileStat).ModTime() { + return true, nil + } + + return false, nil +} + +func createUpstreamTransportWithCertificate(certFile string) (*http.Transport, *os.FileInfo, error) { + transport := &(*http.DefaultTransport.(*http.Transport)) // shallow copy + + // Get the SystemCertPool or continue with an empty pool on error + rootCAs, err := x509.SystemCertPool() + if err != nil { + return nil, nil, err + } + + stat, err := os.Stat(certFile) + if err != nil { + return nil, nil, err + } + + certs, err := ioutil.ReadFile(certFile) + if err != nil { + return nil, nil, err + } + + // Append our cert to the system pool + if ok := rootCAs.AppendCertsFromPEM(certs); !ok { + return nil, nil, errors.New("No certs appended, only system certs present, did you specify the correct cert file?") + } + + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: false, + RootCAs: rootCAs, + } + + return transport, &stat, nil +} + +// UpstreamTransport decides the transport to use for the upstream for the request. +func (r *RegistryMemory) UpstreamTransport(req *http.Request) (http.RoundTripper, error) { + + r.upstreamRequestCount++ + + // Use req to decide the transport per request iff need be. + + // Not using custom certificates on upstreams, so go with default transport no need to use cache + certFile := r.c.ProxyServeUpstreamCaAppendCrtPath() + if certFile == "" { + return http.DefaultTransport, nil + } + + // Decide wether to create a transport or use the cached one. Update the cached transport if it is dirty. + isTransportDirty, err := r.isCachedUpstreamTransportDirty(certFile) + if err != nil { + return nil, err + } + + if isTransportDirty == true { + + transport, stat, err := createUpstreamTransportWithCertificate(certFile) + if err != nil { + return nil, err + } + // Cache the transport and cert meta data + r.cachedCertFilePath = certFile + r.cachedCertFileStat = stat + r.upstreamTransport = transport + } + + return r.upstreamTransport, nil +} + func NewRegistryMemory() *RegistryMemory { return &RegistryMemory{} } diff --git a/proxy/proxy.go b/proxy/proxy.go index a3ed07963c..8e209ab3e0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -41,6 +41,7 @@ type proxyRegistry interface { ProxyRequestHandler() *RequestHandler RuleMatcher() rule.Matcher + UpstreamTransport(r *http.Request) (http.RoundTripper, error) } func NewProxy(r proxyRegistry) *Proxy { @@ -88,7 +89,18 @@ func (d *Proxy) RoundTrip(r *http.Request) (*http.Response, error) { Header: rw.header, }, nil } else if err == nil { - res, err := http.DefaultTransport.RoundTrip(r) + + transport, err := d.r.UpstreamTransport(r) + if err != nil { + d.r.Logger(). + WithError(errors.WithStack(err)). + WithField("granted", false). + WithFields(fields). + Warn("Access request denied because upstream transport creation failed") + return nil, err + } + + res, err := transport.RoundTrip(r) if err != nil { d.r.Logger(). WithError(errors.WithStack(err)). diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 90b4ca3248..de2758acfd 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -22,12 +22,14 @@ package proxy_test import ( "context" + "encoding/pem" "fmt" "io/ioutil" "net/http" "net/http/httptest" "net/http/httputil" "net/url" + "os" "strings" "testing" @@ -429,6 +431,61 @@ func TestConfigureBackendURL(t *testing.T) { } } +func TestUpstreamAppendCaCertToRootCa(t *testing.T) { + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello from TLS Upstream") + })) + defer backend.Close() + + crtBytes := backend.Certificate().Raw + + tmpfile, err := ioutil.TempFile("", "certificate.crt") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + defer tmpfile.Close() + + err = pem.Encode(tmpfile, &pem.Block{Type: "CERTIFICATE", Bytes: crtBytes}) + require.NoError(t, err) + + conf := internal.NewConfigurationWithDefaults() + reg := internal.NewRegistry(conf) + + d := reg.Proxy() + ts := httptest.NewServer(&httputil.ReverseProxy{Director: d.Director, Transport: d}) + defer ts.Close() + + viper.Set(configuration.ViperKeyAuthenticatorNoopIsEnabled, true) + viper.Set(configuration.ViperKeyAuthorizerAllowIsEnabled, true) + viper.Set(configuration.ViperKeyMutatorNoopIsEnabled, true) + viper.Set(configuration.ViperKeyProxyUpstreamCaAppendCrtPath, tmpfile.Name()) + + url := ts.URL + "/test-tls" + + reg.RuleRepository().(*rule.RepositoryMemory).WithRules([]rule.Rule{ + rule.Rule{ + Match: &rule.Match{Methods: []string{"GET"}, URL: url}, + Authenticators: []rule.Handler{{Handler: "noop"}}, + Authorizer: rule.Handler{Handler: "allow"}, + Mutators: []rule.Handler{{Handler: "noop"}}, + Upstream: rule.Upstream{URL: backend.URL}, + }, + }) + + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + response, err := ioutil.ReadAll(res.Body) + require.NoError(t, res.Body.Close()) + require.NoError(t, err) + + assert.True(t, strings.Contains(string(response), "Hello from TLS Upstream")) + + assert.Equal(t, 200, res.StatusCode, "%s", res.Body) +} + // // func BenchmarkDirector(b *testing.B) { // backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {