diff --git a/main.go b/main.go index 7ebef75f..b9f159a1 100644 --- a/main.go +++ b/main.go @@ -71,24 +71,12 @@ func probeHandler(w http.ResponseWriter, r *http.Request, c *config.Config, logg return } - // If a timeout is configured via the Prometheus header, add it to the request. - var timeoutSeconds float64 - if v := r.Header.Get("X-Prometheus-Scrape-Timeout-Seconds"); v != "" { - var err error - timeoutSeconds, err = strconv.ParseFloat(v, 64) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to parse timeout from Prometheus header: %s", err), http.StatusInternalServerError) - return - } - } - if timeoutSeconds == 0 { - timeoutSeconds = 10 + timeoutSeconds, err := getTimeout(r, module, *timeoutOffset) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to parse timeout from Prometheus header: %s", err), http.StatusInternalServerError) + return } - if module.Timeout.Seconds() < timeoutSeconds && module.Timeout.Seconds() > 0 { - timeoutSeconds = module.Timeout.Seconds() - } - timeoutSeconds -= *timeoutOffset ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second))) defer cancel() r = r.WithContext(ctx) @@ -355,3 +343,26 @@ func run() int { } } + +func getTimeout(r *http.Request, module config.Module, offset float64) (timeoutSeconds float64, err error) { + // If a timeout is configured via the Prometheus header, add it to the request. + if v := r.Header.Get("X-Prometheus-Scrape-Timeout-Seconds"); v != "" { + var err error + timeoutSeconds, err = strconv.ParseFloat(v, 64) + if err != nil { + return 0, err + } + } + if timeoutSeconds == 0 { + timeoutSeconds = 10 + } + + var maxTimeoutSeconds = timeoutSeconds - offset + if module.Timeout.Seconds() < maxTimeoutSeconds && module.Timeout.Seconds() > 0 { + timeoutSeconds = module.Timeout.Seconds() + } else { + timeoutSeconds = maxTimeoutSeconds + } + + return timeoutSeconds, nil +} diff --git a/main_test.go b/main_test.go index 1cbf5a6f..ad69ae4a 100644 --- a/main_test.go +++ b/main_test.go @@ -102,3 +102,38 @@ func TestDebugOutputSecretsHidden(t *testing.T) { t.Errorf("Hidden secret missing from debug output: %v", out) } } + +func TestTimeoutIsSetCorrectly(t *testing.T) { + var tests = []struct { + inModuleTimeout time.Duration + inPrometheusTimeout string + inOffset float64 + outTimeout float64 + }{ + {0 * time.Second, "15", 0.5, 14.5}, + {0 * time.Second, "15", 0, 15}, + {20 * time.Second, "15", 0.5, 14.5}, + {20 * time.Second, "15", 0, 15}, + {5 * time.Second, "15", 0, 5}, + {5 * time.Second, "15", 0.5, 5}, + {10 * time.Second, "", 0.5, 9.5}, + {10 * time.Second, "10", 0.5, 9.5}, + {9500 * time.Millisecond, "", 0.5, 9.5}, + {9500 * time.Millisecond, "", 1, 9}, + {0 * time.Second, "", 0.5, 9.5}, + {0 * time.Second, "", 0, 10}, + } + + for _, v := range tests { + request, _ := http.NewRequest("GET", "", nil) + request.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", v.inPrometheusTimeout) + module := config.Module{ + Timeout: v.inModuleTimeout, + } + + timeout, _ := getTimeout(request, module, v.inOffset) + if timeout != v.outTimeout { + t.Errorf("timeout is incorrect: %v, want %v", timeout, v.outTimeout) + } + } +}