Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ext): added hsts middleware #29

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 0 additions & 93 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,99 +747,6 @@ func TestDataBindOnHtml(t *testing.T) {

}

func TestMiddleware(t *testing.T) {
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
defer srv.Close()

i := 0

app := New(WithMux(mux))
app.Use(func(next HandleFunc) HandleFunc {
return func(c *Context) error {
i++
c.WriteHeader("X-M1", strconv.Itoa(i))

return next(c)
}
}, func(next HandleFunc) HandleFunc {
return func(c *Context) error {
i++
c.WriteHeader("X-M2", strconv.Itoa(i))
return next(c)
}
})

app.Use(func(next HandleFunc) HandleFunc {
return func(c *Context) error {
i++
c.WriteHeader("X-M3", strconv.Itoa(i))
user := c.Request().Header.Get("X-User")
if user == "" {
c.WriteStatus(http.StatusUnauthorized)
return ErrCancelled
}

if user != "yaitoo" {
c.WriteStatus(http.StatusForbidden)
return ErrCancelled
}

return next(c)
}
})

app.Get("/", func(c *Context) error {
return c.View(nil)
})

go app.Start()
defer app.Close()

req, err := http.NewRequest("GET", srv.URL+"/", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)

require.Equal(t, "1", resp.Header.Get("X-M1"))
require.Equal(t, "2", resp.Header.Get("X-M2"))
require.Equal(t, "3", resp.Header.Get("X-M3"))
_, err = io.Copy(io.Discard, resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
resp.Body.Close()

i = 0
req, err = http.NewRequest("GET", srv.URL+"/", nil)
req.Header.Set("X-User", "xun")
require.NoError(t, err)
resp, err = client.Do(req)
require.NoError(t, err)

require.Equal(t, "1", resp.Header.Get("X-M1"))
require.Equal(t, "2", resp.Header.Get("X-M2"))
require.Equal(t, "3", resp.Header.Get("X-M3"))
_, err = io.Copy(io.Discard, resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
resp.Body.Close()

i = 0
req, err = http.NewRequest("GET", srv.URL+"/", nil)
req.Header.Set("X-User", "yaitoo")
require.NoError(t, err)
resp, err = client.Do(req)
require.NoError(t, err)

require.Equal(t, "1", resp.Header.Get("X-M1"))
require.Equal(t, "2", resp.Header.Get("X-M2"))
require.Equal(t, "3", resp.Header.Get("X-M3"))
_, err = io.Copy(io.Discard, resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
resp.Body.Close()
}

func TestUnhandledError(t *testing.T) {
fsys := &fstest.MapFS{
"public/skin.css": &fstest.MapFile{},
Expand Down
77 changes: 77 additions & 0 deletions ext/hsts/hsts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package hsts

import (
"net"
"net/http"
"strconv"
"time"

"github.com/yaitoo/xun"
)

// Enable sets the Strict-Transport-Security header with the given maxAge,
// includeSubdomains and preload values.
//
// The Strict-Transport-Security header is used to inform browsers that the site
// should only be accessed over HTTPS, and that any HTTP requests should be
// automatically rewritten as HTTPS.
//
// maxAge is the maximum age of the header in seconds.
//
// includeSubdomains will include all subdomains of the current domain in the
// header.
//
// preload will add the preload directive to the header, which allows the site
// to be included in the HSTS preload list.
//
// The HSTS preload list is a list of sites that are known to be HTTPS-only, and
// are included in the browser's HSTS list by default. This allows the browser to
// immediately switch to HTTPS for these sites, without having to wait for the
// first request to complete.
func Enable(maxAge time.Duration, includeSubdomains, preload bool) xun.Middleware {
return func(next xun.HandleFunc) xun.HandleFunc {
return func(c *xun.Context) error {
r := c.Request()

isHTTPS := false
// Check X-Forwarded-Proto header first
forwardedProto := r.Header.Get("X-Forwarded-Proto")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 issue (security): The X-Forwarded-Proto header should only be trusted from known proxy IPs to prevent spoofing.

Consider adding a configuration option to specify trusted proxy IPs and only accept the X-Forwarded-Proto header from these addresses.

if forwardedProto != "" {
isHTTPS = forwardedProto == "https"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 issue (security): The X-Forwarded-Proto header should only be trusted from known proxies

Consider adding validation to ensure this header is only processed when coming from trusted proxy IPs. Otherwise, malicious clients could bypass HTTPS enforcement by spoofing this header.

} else {
// Fall back to checking direct protocol
isHTTPS = r.TLS != nil
}

if isHTTPS && (r.Method == "GET" || r.Method == "HEAD") {
cnlangzi marked this conversation as resolved.
Show resolved Hide resolved
target := "https://" + stripPort(r.Host) + r.URL.RequestURI()

if maxAge <= 0 {
maxAge = 365 * 24 * time.Hour
}

v := "max-age=" + strconv.FormatInt(int64(maxAge/time.Second), 10)
if includeSubdomains {
v += "; includeSubDomains"
}
if preload {
v += "; preload"
}
c.WriteHeader("Strict-Transport-Security", v)

c.Redirect(target, http.StatusFound)
return xun.ErrCancelled
}

return next(c)
}
}
}

func stripPort(hostPort string) string {
host, _, err := net.SplitHostPort(hostPort)
if err != nil {
return hostPort
}
return host
}
167 changes: 167 additions & 0 deletions ext/hsts/hsts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package hsts

import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/yaitoo/xun"
)

func TestHstsMiddleware(t *testing.T) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Suggest adding tests for HTTPS requests.

It appears that the tests primarily focus on HTTP to HTTPS redirection. It's important to also verify the behavior when a request is already made over HTTPS. Specifically, ensure that the HSTS header is correctly set and that no redirection occurs in this scenario.

Suggested implementation:

func TestHstsMiddleware(t *testing.T) {
	tests := []struct {
		name           string
		scheme         string
		expectedStatus int
		checkRedirect bool
		checkHSTS     bool
	}{
		{
			name:           "HTTP request should redirect to HTTPS",
			scheme:         "http",
			expectedStatus: http.StatusMovedPermanently,
			checkRedirect: true,
			checkHSTS:     false,
		},
		{
			name:           "HTTPS request should include HSTS header",
			scheme:         "https",
			expectedStatus: http.StatusOK,
			checkRedirect: false,
			checkHSTS:     true,
		},
	}

	tr := http.DefaultTransport.(*http.Transport).Clone()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := http.Client{
				Transport: tr,
				CheckRedirect: func(req *http.Request, via []*http.Request) error {
					return http.ErrUseLastResponse
				},
			}

			server := httptest.NewTLSServer(xun.HSTS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
			})))
			defer server.Close()

			serverURL, _ := url.Parse(server.URL)
			req, _ := http.NewRequest("GET", fmt.Sprintf("%s://%s", tt.scheme, serverURL.Host), nil)

			resp, err := c.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			require.Equal(t, tt.expectedStatus, resp.StatusCode)

			if tt.checkRedirect {
				location := resp.Header.Get("Location")
				require.True(t, strings.HasPrefix(location, "https://"))
			}

			if tt.checkHSTS {
				hstsHeader := resp.Header.Get("Strict-Transport-Security")
				require.NotEmpty(t, hstsHeader)
				require.Contains(t, hstsHeader, "max-age=")
			}
		})
	}

You'll need to add these imports if they're not already present:

  • "fmt"
  • "net/http"
  • "crypto/tls"
  • "context"
  • "net"

The test now uses a table-driven approach to test both HTTP and HTTPS scenarios. For the HTTPS case, it verifies:

  1. The response status is 200 OK (no redirect)
  2. The HSTS header is present
  3. The HSTS header contains the required max-age directive

You may need to adjust the expected status codes and header values based on your specific HSTS middleware implementation.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Suggest adding tests for different HTTP methods.

The current tests only cover GET requests. It would be beneficial to add tests for other HTTP methods like POST, PUT, DELETE, etc., to ensure the middleware behaves correctly in all scenarios. For non-GET/HEAD requests, the middleware should not redirect and HSTS header should not be set.

Suggested implementation:

func TestHstsMiddleware(t *testing.T) {
	testCases := []struct {
		name           string
		method         string
		expectedStatus int
		shouldRedirect bool
		shouldHaveHSTS bool
	}{
		{
			name:           "GET request",
			method:         http.MethodGet,
			expectedStatus: http.StatusOK,
			shouldRedirect: true,
			shouldHaveHSTS: true,
		},
		{
			name:           "HEAD request",
			method:         http.MethodHead,
			expectedStatus: http.StatusOK,
			shouldRedirect: true,
			shouldHaveHSTS: true,
		},
		{
			name:           "POST request",
			method:         http.MethodPost,
			expectedStatus: http.StatusOK,
			shouldRedirect: false,
			shouldHaveHSTS: false,
		},
		{
			name:           "PUT request",
			method:         http.MethodPut,
			expectedStatus: http.StatusOK,
			shouldRedirect: false,
			shouldHaveHSTS: false,
		},
		{
			name:           "DELETE request",
			method:         http.MethodDelete,
			expectedStatus: http.StatusOK,
			shouldRedirect: false,
			shouldHaveHSTS: false,
		},
	}

	tr := http.DefaultTransport.(*http.Transport).Clone()

You'll also need to:

  1. Modify the test implementation to iterate over the test cases
  2. Create requests using the specified method for each test case
  3. Add assertions to verify:
    • Redirect behavior based on shouldRedirect
    • HSTS header presence based on shouldHaveHSTS
    • Expected status code

The exact implementation will depend on how the rest of the test is structured, but the test cases provide a framework for comprehensive HTTP method testing.


tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // skipcq: GSC-G402,GO-S1020
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { // skipcq: RVV-B0012
if strings.HasPrefix(addr, "abc.com") {
return net.Dial("tcp", strings.TrimPrefix(addr, "abc.com"))
}
return net.Dial("tcp", addr)
}

c := http.Client{
Transport: tr,
CheckRedirect: func(req *http.Request, via []*http.Request) error { // skipcq: RVV-B0012
return http.ErrUseLastResponse
},
}

t.Run("max_age_should_work", func(t *testing.T) {
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
defer srv.Close()

u, err := url.Parse(srv.URL)
require.NoError(t, err)

l := "https://" + u.Hostname() + "/"
app := xun.New(xun.WithMux(mux))

app.Use(Enable(1*time.Hour, false, false))

app.Get("/", func(c *xun.Context) error {
return c.View(nil)
})

req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode)
require.Equal(t, l, resp.Header.Get("Location"))
require.Equal(t, "max-age=3600", resp.Header.Get("Strict-Transport-Security")) // default MaxAge is 1 year
})

t.Run("invalid_max_age_should_work", func(t *testing.T) {
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
defer srv.Close()

u, err := url.Parse(srv.URL)
require.NoError(t, err)

l := "https://" + u.Hostname() + "/"
app := xun.New(xun.WithMux(mux))

app.Use(Enable(0*time.Hour, false, false))

app.Get("/", func(c *xun.Context) error {
return c.View(nil)
})

req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode)
require.Equal(t, l, resp.Header.Get("Location"))
require.Equal(t, "max-age=31536000", resp.Header.Get("Strict-Transport-Security"))
})

t.Run("max_age_includesubdomains_should_work", func(t *testing.T) {
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
defer srv.Close()

u, err := url.Parse(srv.URL)
require.NoError(t, err)

l := "https://" + u.Hostname() + "/"
app := xun.New(xun.WithMux(mux))

app.Use(Enable(1*time.Hour, true, false))

app.Get("/", func(c *xun.Context) error {
return c.View(nil)
})

req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode)
require.Equal(t, l, resp.Header.Get("Location"))
require.Equal(t, "max-age=3600; includeSubDomains", resp.Header.Get("Strict-Transport-Security"))
})

t.Run("max_age_includesubdomains_preload_should_work", func(t *testing.T) {
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
defer srv.Close()

u, err := url.Parse(srv.URL)
require.NoError(t, err)

l := "https://" + u.Hostname() + "/"
app := xun.New(xun.WithMux(mux))

app.Use(Enable(1*time.Hour, true, true))

app.Get("/", func(c *xun.Context) error {
return c.View(nil)
})

req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode)
require.Equal(t, l, resp.Header.Get("Location"))
require.Equal(t, "max-age=3600; includeSubDomains; preload", resp.Header.Get("Strict-Transport-Security"))
})

t.Run("without_port_should_work", func(t *testing.T) {
mux := http.NewServeMux()

srv := &http.Server{ // skipcq: GO-S2112
Addr: ":80",
Handler: mux,
}
defer srv.Close()
go srv.ListenAndServe() // nolint: errcheck

l := "https://abc.com/"
app := xun.New(xun.WithMux(mux))

app.Use(Enable(1*time.Hour, true, true))

app.Get("/", func(c *xun.Context) error {
return c.View(nil)
})

req, err := http.NewRequest(http.MethodGet, "http://abc.com/", nil) // skipcq: GO-S1028
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode)
require.Equal(t, l, resp.Header.Get("Location"))
require.Equal(t, "max-age=3600; includeSubDomains; preload", resp.Header.Get("Strict-Transport-Security"))
})
}
Loading
Loading