diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 9e23c73..2ac508e 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -44,7 +44,13 @@ func TestService_RedirectToHTTPSWhenTLSRequired(t *testing.T) { } func TestService_DontRedirectToHTTPSWhenTLSAndPlainHTTPAllowed(t *testing.T) { - service := testCreateService(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true, TLSDisableRedirect: true}, defaultTargetOptions) + var forwardedProto string + + service := testCreateServiceWithHandler(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true, TLSDisableRedirect: true}, defaultTargetOptions, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + forwardedProto = r.Header.Get("X-Forwarded-Proto") + }), + ) require.True(t, service.options.TLSEnabled) @@ -53,12 +59,14 @@ func TestService_DontRedirectToHTTPSWhenTLSAndPlainHTTPAllowed(t *testing.T) { service.ServeHTTP(w, req) require.Equal(t, http.StatusOK, w.Result().StatusCode) + assert.Equal(t, "http", forwardedProto) req = httptest.NewRequest(http.MethodGet, "https://example.com", nil) w = httptest.NewRecorder() service.ServeHTTP(w, req) require.Equal(t, http.StatusOK, w.Result().StatusCode) + assert.Equal(t, "https", forwardedProto) } func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) { @@ -154,7 +162,13 @@ func TestService_MarshallingState(t *testing.T) { } func testCreateService(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions) *Service { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + return testCreateServiceWithHandler(t, hosts, options, targetOptions, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) +} + +func testCreateServiceWithHandler(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions, handler http.Handler) *Service { + server := httptest.NewServer(handler) t.Cleanup(server.Close) serverURL, err := url.Parse(server.URL)