diff --git a/command/server.go b/command/server.go index 2a240dfcf1ba..ede3611661d7 100644 --- a/command/server.go +++ b/command/server.go @@ -97,7 +97,8 @@ type ServerCommand struct { type ServerListener struct { net.Listener - config map[string]interface{} + config map[string]interface{} + maxRequestSize int64 } func (c *ServerCommand) Synopsis() string { @@ -689,11 +690,6 @@ CLUSTER_SYNTHESIS_COMPLETE: return 1 } - lns = append(lns, ServerListener{ - Listener: ln, - config: lnConfig.Config, - }) - if reloadFunc != nil { relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type] relSlice = append(relSlice, reloadFunc) @@ -728,6 +724,26 @@ CLUSTER_SYNTHESIS_COMPLETE: props["cluster address"] = addr } + var maxRequestSize int64 = vaulthttp.DefaultMaxRequestSize + if valRaw, ok := lnConfig.Config["max_request_size"]; ok { + val, err := parseutil.ParseInt(valRaw) + if err != nil { + c.UI.Error(fmt.Sprintf("Could not parse max_request_size value %v", valRaw)) + return 1 + } + + if val >= 0 { + maxRequestSize = val + } + } + props["max_request_size"] = fmt.Sprintf("%d", maxRequestSize) + + lns = append(lns, ServerListener{ + Listener: ln, + config: lnConfig.Config, + maxRequestSize: maxRequestSize, + }) + // Store the listener props for output later key := fmt.Sprintf("listener %d", i+1) propsList := make([]string, 0, len(props)) @@ -792,7 +808,9 @@ CLUSTER_SYNTHESIS_COMPLETE: // This needs to happen before we first unseal, so before we trigger dev // mode if it's set core.SetClusterListenerAddrs(clusterAddrs) - core.SetClusterHandler(vaulthttp.Handler(core)) + core.SetClusterHandler(vaulthttp.Handler(&vault.HandlerProperties{ + Core: core, + })) err = core.UnsealWithStoredKeys(context.Background()) if err != nil { @@ -925,7 +943,10 @@ CLUSTER_SYNTHESIS_COMPLETE: // Initialize the HTTP servers for _, ln := range lns { - handler := vaulthttp.Handler(core) + handler := vaulthttp.Handler(&vault.HandlerProperties{ + Core: core, + MaxRequestSize: ln.maxRequestSize, + }) // We perform validation on the config earlier, we can just cast here if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok { @@ -1195,7 +1216,9 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m c.UI.Output("") for _, core := range testCluster.Cores { - core.Server.Handler = vaulthttp.Handler(core.Core) + core.Server.Handler = vaulthttp.Handler(&vault.HandlerProperties{ + Core: core.Core, + }) core.SetClusterHandler(core.Server.Handler) } diff --git a/command/server/config.go b/command/server/config.go index 33c98db4a6d3..7a5212aa999e 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -804,6 +804,7 @@ func parseListeners(result *Config, list *ast.ObjectList) error { "x_forwarded_for_reject_not_authorized", "x_forwarded_for_reject_not_present", "infrastructure", + "max_request_size", "node_id", "proxy_protocol_behavior", "proxy_protocol_authorized_addrs", diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index 92e6cb152426..0a4973e9f84e 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -4,6 +4,9 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "errors" + "io" + "io/ioutil" "net/http" "net/url" "os" @@ -56,11 +59,30 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request } func GenerateForwardedRequest(req *http.Request) (*Request, error) { + var reader io.Reader = req.Body + ctx := req.Context() + maxRequestSize := ctx.Value("max_request_size") + if maxRequestSize != nil { + max, ok := maxRequestSize.(int64) + if !ok { + return nil, errors.New("could not parse max_request_size from request context") + } + if max > 0 { + reader = io.LimitReader(req.Body, max) + } + } + + body, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + fq := Request{ Method: req.Method, HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)), Host: req.Host, RemoteAddr: req.RemoteAddr, + Body: body, } reqURL := req.URL @@ -80,13 +102,6 @@ func GenerateForwardedRequest(req *http.Request) (*Request, error) { } } - buf := bytes.NewBuffer(nil) - _, err := buf.ReadFrom(req.Body) - if err != nil { - return nil, err - } - fq.Body = buf.Bytes() - if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 { fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates)) for i, cert := range req.TLS.PeerCertificates { diff --git a/http/forwarded_for_test.go b/http/forwarded_for_test.go index 0eec439f4adc..170b54334dc2 100644 --- a/http/forwarded_for_test.go +++ b/http/forwarded_for_test.go @@ -24,7 +24,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // First: test reject not present t.Run("reject_not_present", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -69,7 +69,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test allow unauth t.Run("allow_unauth", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -106,7 +106,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test fail unauth t.Run("fail_unauth", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -140,7 +140,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test bad hops (too many) t.Run("too_many_hops", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -174,7 +174,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test picking correct value t.Run("correct_hop_skipping", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -211,7 +211,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: multi-header approach t.Run("correct_hop_skipping_multi_header", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) diff --git a/http/handler.go b/http/handler.go index a9be673cb675..6cfd4a7b99e2 100644 --- a/http/handler.go +++ b/http/handler.go @@ -1,7 +1,9 @@ package http import ( + "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -52,10 +54,11 @@ const ( // soft-mandatory Sentinel policies. PolicyOverrideHeaderName = "X-Vault-Policy-Override" - // MaxRequestSize is the maximum accepted request size. This is to prevent - // a denial of service attack where no Content-Length is provided and the server - // is fed ever more data until it exhausts memory. - MaxRequestSize = 32 * 1024 * 1024 + // DefaultMaxRequestSize is the default maximum accepted request size. This + // is to prevent a denial of service attack where no Content-Length is + // provided and the server is fed ever more data until it exhausts memory. + // Can be overridden per listener. + DefaultMaxRequestSize = 32 * 1024 * 1024 ) var ( @@ -67,7 +70,9 @@ var ( // Handler returns an http.Handler for the API. This can be used on // its own to mount the Vault API within another web server. -func Handler(core *vault.Core) http.Handler { +func Handler(props *vault.HandlerProperties) http.Handler { + core := props.Core + // Create the muxer to handle the actual endpoints mux := http.NewServeMux() mux.Handle("/v1/sys/init", handleSysInit(core)) @@ -108,7 +113,7 @@ func Handler(core *vault.Core) http.Handler { // Wrap the help wrapped handler with another layer with a generic // handler - genericWrappedHandler := wrapGenericHandler(corsWrappedHandler) + genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize) // Wrap the handler with PrintablePathCheckHandler to check for non-printable // characters in the request path. @@ -120,12 +125,20 @@ func Handler(core *vault.Core) http.Handler { // wrapGenericHandler wraps the handler with an extra layer of handler where // tasks that should be commonly handled for all the requests and/or responses // are performed. -func wrapGenericHandler(h http.Handler) http.Handler { +func wrapGenericHandler(h http.Handler, maxRequestSize int64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") - h.ServeHTTP(w, r) + + // Add a context and put the request limit for this handler in it + if maxRequestSize > 0 { + ctx := context.WithValue(r.Context(), "max_request_size", maxRequestSize) + h.ServeHTTP(w, r.WithContext(ctx)) + } else { + h.ServeHTTP(w, r) + } + return }) } @@ -326,8 +339,19 @@ func (fs *UIAssetWrapper) Open(name string) (http.File, error) { func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error { // Limit the maximum number of bytes to MaxRequestSize to protect // against an indefinite amount of data being read. - limit := http.MaxBytesReader(w, r.Body, MaxRequestSize) - err := jsonutil.DecodeJSONFromReader(limit, out) + reader := r.Body + ctx := r.Context() + maxRequestSize := ctx.Value("max_request_size") + if maxRequestSize != nil { + max, ok := maxRequestSize.(int64) + if !ok { + return errors.New("could not parse max_request_size from request context") + } + if max > 0 { + reader = http.MaxBytesReader(w, r.Body, max) + } + } + err := jsonutil.DecodeJSONFromReader(reader, out) if err != nil && err != io.EOF { return errwrap.Wrapf("failed to parse JSON input: {{err}}", err) } diff --git a/http/logical_test.go b/http/logical_test.go index e6ec3da29374..cd868bcfad92 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -261,7 +261,7 @@ func TestLogical_RequestSizeLimit(t *testing.T) { // Write a very large object, should fail resp := testHttpPut(t, token, addr+"/v1/secret/foo", map[string]interface{}{ - "data": make([]byte, MaxRequestSize), + "data": make([]byte, DefaultMaxRequestSize), }) testResponseStatus(t, resp, 413) } diff --git a/http/testing.go b/http/testing.go index 2299006c98bf..13501f5daf19 100644 --- a/http/testing.go +++ b/http/testing.go @@ -30,7 +30,10 @@ func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *v // for tests. mux := http.NewServeMux() mux.Handle("/_test/auth", http.HandlerFunc(testHandleAuth)) - mux.Handle("/", Handler(core)) + mux.Handle("/", Handler(&vault.HandlerProperties{ + Core: core, + MaxRequestSize: DefaultMaxRequestSize, + })) server := &http.Server{ Addr: ln.Addr().String(), diff --git a/vault/request_handling.go b/vault/request_handling.go index fd91e33dbc85..a6424b362673 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -26,6 +26,13 @@ const ( replTimeout = 10 * time.Second ) +// HanlderProperties is used to seed configuration into a vaulthttp.Handler. +// It's in this package to avoid a circular dependency +type HandlerProperties struct { + Core *Core + MaxRequestSize int64 +} + // fetchEntityAndDerivedPolicies returns the entity object for the given entity // ID. If the entity is merged into a different entity object, the entity into // which the given entity ID is merged into will be returned. This function diff --git a/vault/testing.go b/vault/testing.go index da8985e81fc9..99f0d65a5b7b 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -880,7 +880,7 @@ type TestClusterCore struct { type TestClusterOptions struct { KeepStandbysSealed bool SkipInit bool - HandlerFunc func(*Core) http.Handler + HandlerFunc func(*HandlerProperties) http.Handler BaseListenAddress string NumCores int SealFunc func() Seal @@ -1249,7 +1249,9 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te } cores = append(cores, c) if opts != nil && opts.HandlerFunc != nil { - handlers[i] = opts.HandlerFunc(c) + handlers[i] = opts.HandlerFunc(&HandlerProperties{ + Core: c, + }) servers[i].Handler = handlers[i] } } diff --git a/website/source/docs/configuration/listener/tcp.html.md b/website/source/docs/configuration/listener/tcp.html.md index e8a1c3fff718..28eefbbe08cf 100644 --- a/website/source/docs/configuration/listener/tcp.html.md +++ b/website/source/docs/configuration/listener/tcp.html.md @@ -34,6 +34,10 @@ advertise the correct address to other nodes. they need to hop through a TCP load balancer or some other scheme in order to talk. +- `max_request_size` `(int: 33554432)` – Specifies a hard maximum allowed + request size, in bytes. Defaults to 32 MB. Specifying a number less than or + equal to `0` turns off limiting altogether. + - `proxy_protocol_behavior` `(string: "") – When specified, turns on the PROXY protocol for the listener. Accepted Values: