Skip to content

Commit

Permalink
Verify forwarded proto in service test
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmcconnell committed Jan 7, 2025
1 parent d7444e5 commit 84d918d
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions internal/server/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ func TestService_RedirectToHTTPSWhenTLSRequired(t *testing.T) {
}

func TestService_DontRedirectToHTTPSWhenTLSAndPlainHTTPAllowed(t *testing.T) {
service := testCreateService(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true, TLSDisableRedirect: true}, defaultTargetOptions)
var forwardedProto string

service := testCreateServiceWithHandler(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true, TLSDisableRedirect: true}, defaultTargetOptions,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
forwardedProto = r.Header.Get("X-Forwarded-Proto")
}),
)

require.True(t, service.options.TLSEnabled)

Expand All @@ -53,12 +59,14 @@ func TestService_DontRedirectToHTTPSWhenTLSAndPlainHTTPAllowed(t *testing.T) {
service.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Result().StatusCode)
assert.Equal(t, "http", forwardedProto)

req = httptest.NewRequest(http.MethodGet, "https://example.com", nil)
w = httptest.NewRecorder()
service.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Result().StatusCode)
assert.Equal(t, "https", forwardedProto)
}

func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) {
Expand Down Expand Up @@ -154,7 +162,13 @@ func TestService_MarshallingState(t *testing.T) {
}

func testCreateService(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions) *Service {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
return testCreateServiceWithHandler(t, hosts, options, targetOptions,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
}

func testCreateServiceWithHandler(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions, handler http.Handler) *Service {
server := httptest.NewServer(handler)
t.Cleanup(server.Close)

serverURL, err := url.Parse(server.URL)
Expand Down

0 comments on commit 84d918d

Please sign in to comment.