diff --git a/CONFIGURATION.md b/CONFIGURATION.md index 2d835848..6ac0fb34 100644 --- a/CONFIGURATION.md +++ b/CONFIGURATION.md @@ -49,6 +49,11 @@ The other placeholders are specified separately. headers: [ <string>: <string> ... ] + # The maximum body length in bytes that will be read from the server. A value of 0 means no limit. + # If the response includes a Content-Length header, it is NOT validated against this value. This + # setting is only meant to limit the amount of data that you are willing to read from the server. + [ max_response_length: <int> | default = 0 ] + # The compression algorithm to use to decompress the response (gzip, br, deflate, identity). # # If an "Accept-Encoding" header is specified, it MUST be such that the compression algorithm diff --git a/config/config.go b/config/config.go index 07e84c05..5a7c5b10 100644 --- a/config/config.go +++ b/config/config.go @@ -207,6 +207,7 @@ type HTTPProbe struct { Body string `yaml:"body,omitempty"` HTTPClientConfig config.HTTPClientConfig `yaml:"http_client_config,inline"` Compression string `yaml:"compression,omitempty"` + MaxResponseLength int64 `yaml:"max_response_length,omitempty"` } type HeaderMatch struct { @@ -287,6 +288,11 @@ func (s *HTTPProbe) UnmarshalYAML(unmarshal func(interface{}) error) error { if err := unmarshal((*plain)(s)); err != nil { return err } + + if s.MaxResponseLength < 0 { + return fmt.Errorf(`invalid max_response_length value: %d`, s.MaxResponseLength) + } + if err := s.HTTPClientConfig.Validate(); err != nil { return err } diff --git a/prober/http.go b/prober/http.go index 95c53535..56eff73c 100644 --- a/prober/http.go +++ b/prober/http.go @@ -474,6 +474,14 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr } } + // If there's a configured max_response_length, wrap the body in the response in an http.MaxBytesReader. + // This will read up to MaxResponseLength bytes from the body, and return an error if the response is + // larger. 
It forwards the Close call to the original resp.Body to make sure the TCP connection is + // correctly shut down. The limit is applied _before_ decompression. + if httpConfig.MaxResponseLength > 0 { + resp.Body = http.MaxBytesReader(nil, resp.Body, httpConfig.MaxResponseLength) + } + // Since the configuration specifies a compression algorithm, blindly treat the response body as a // compressed payload; if we cannot decompress it it's a failure because the configuration says we // should expect the response to be compressed in that way. diff --git a/prober/http_test.go b/prober/http_test.go index 6679b3c0..9824217b 100644 --- a/prober/http_test.go +++ b/prober/http_test.go @@ -513,6 +513,108 @@ func TestHandlingOfCompressionSetting(t *testing.T) { } } +func TestMaxResponseLength(t *testing.T) { + const max = 128 + + var gzippedPayload bytes.Buffer + enc := gzip.NewWriter(&gzippedPayload) + enc.Write(bytes.Repeat([]byte{'A'}, 2*max)) + enc.Close() + + testcases := map[string]struct { + target string + compression string + expectedMetrics map[string]float64 + expectFailure bool + }{ + "short": { + target: "/short", + expectedMetrics: map[string]float64{ + "probe_http_uncompressed_body_length": float64(max - 1), + "probe_http_content_length": float64(max - 1), + }, + }, + "long": { + target: "/long", + expectFailure: true, + expectedMetrics: map[string]float64{ + "probe_http_content_length": float64(max + 1), + }, + }, + "compressed": { + target: "/compressed", + compression: "gzip", + expectedMetrics: map[string]float64{ + "probe_http_content_length": float64(gzippedPayload.Len()), + "probe_http_uncompressed_body_length": float64(2 * max), + }, + }, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var resp []byte + + switch r.URL.Path { + case "/compressed": + resp = gzippedPayload.Bytes() + w.Header().Add("Content-Encoding", "gzip") + + case "/long": + resp = bytes.Repeat([]byte{'A'}, max+1) + + case "/short": + 
resp = bytes.Repeat([]byte{'A'}, max-1) + + default: + w.WriteHeader(http.StatusBadRequest) + return + } + + w.Header().Set("Content-Length", strconv.Itoa(len(resp))) + w.WriteHeader(http.StatusOK) + w.Write(resp) + })) + defer ts.Close() + + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + registry := prometheus.NewRegistry() + testCTX, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + result := ProbeHTTP( + testCTX, + ts.URL+tc.target, + config.Module{ + Timeout: time.Second, + HTTP: config.HTTPProbe{ + IPProtocolFallback: true, + MaxResponseLength: max, + HTTPClientConfig: pconfig.DefaultHTTPClientConfig, + Compression: tc.compression, + }, + }, + registry, + log.NewNopLogger(), + ) + + switch { + case tc.expectFailure && result: + t.Fatalf("test passed unexpectedly") + case !tc.expectFailure && !result: + t.Fatalf("test failed unexpectedly") + } + + mfs, err := registry.Gather() + if err != nil { + t.Fatal(err) + } + + checkRegistryResults(tc.expectedMetrics, mfs, t) + }) + } +} + func TestRedirectFollowed(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" {