-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(ext): added hsts middleware #29
Changes from 2 commits
b0c598d
edd41fa
4c62c54
34b46cb
b7dcc6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package hsts | ||
|
||
import ( | ||
"net" | ||
"net/http" | ||
"strconv" | ||
"time" | ||
|
||
"github.com/yaitoo/xun" | ||
) | ||
|
||
// Enable sets the Strict-Transport-Security header with the given maxAge, | ||
// includeSubdomains and preload values. | ||
// | ||
// The Strict-Transport-Security header is used to inform browsers that the site | ||
// should only be accessed over HTTPS, and that any HTTP requests should be | ||
// automatically rewritten as HTTPS. | ||
// | ||
// maxAge is the maximum age of the header in seconds. | ||
// | ||
// includeSubdomains will include all subdomains of the current domain in the | ||
// header. | ||
// | ||
// preload will add the preload directive to the header, which allows the site | ||
// to be included in the HSTS preload list. | ||
// | ||
// The HSTS preload list is a list of sites that are known to be HTTPS-only, and | ||
// are included in the browser's HSTS list by default. This allows the browser to | ||
// immediately switch to HTTPS for these sites, without having to wait for the | ||
// first request to complete. | ||
func Enable(maxAge time.Duration, includeSubdomains, preload bool) xun.Middleware { | ||
return func(next xun.HandleFunc) xun.HandleFunc { | ||
return func(c *xun.Context) error { | ||
r := c.Request() | ||
|
||
isHTTPS := false | ||
// Check X-Forwarded-Proto header first | ||
forwardedProto := r.Header.Get("X-Forwarded-Proto") | ||
if forwardedProto != "" { | ||
isHTTPS = forwardedProto == "https" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚨 issue (security): The X-Forwarded-Proto header should only be trusted from known proxies Consider adding validation to ensure this header is only processed when coming from trusted proxy IPs. Otherwise, malicious clients could bypass HTTPS enforcement by spoofing this header. |
||
} else { | ||
// Fall back to checking direct protocol | ||
isHTTPS = r.TLS != nil | ||
} | ||
|
||
if isHTTPS && (r.Method == "GET" || r.Method == "HEAD") { | ||
cnlangzi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
target := "https://" + stripPort(r.Host) + r.URL.RequestURI() | ||
|
||
if maxAge <= 0 { | ||
maxAge = 365 * 24 * time.Hour | ||
} | ||
|
||
v := "max-age=" + strconv.FormatInt(int64(maxAge/time.Second), 10) | ||
if includeSubdomains { | ||
v += "; includeSubDomains" | ||
} | ||
if preload { | ||
v += "; preload" | ||
} | ||
c.WriteHeader("Strict-Transport-Security", v) | ||
|
||
c.Redirect(target, http.StatusFound) | ||
return xun.ErrCancelled | ||
} | ||
|
||
return next(c) | ||
} | ||
} | ||
} | ||
|
||
func stripPort(hostPort string) string { | ||
host, _, err := net.SplitHostPort(hostPort) | ||
if err != nil { | ||
return hostPort | ||
} | ||
return host | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
package hsts | ||
|
||
import ( | ||
"context" | ||
"crypto/tls" | ||
"net" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/yaitoo/xun" | ||
) | ||
|
||
func TestHstsMiddleware(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Suggest adding tests for HTTPS requests. It appears that the tests primarily focus on HTTP to HTTPS redirection. It's important to also verify the behavior when a request is already made over HTTPS. Specifically, ensure that the HSTS header is correctly set and that no redirection occurs in this scenario. Suggested implementation: func TestHstsMiddleware(t *testing.T) {
tests := []struct {
name string
scheme string
expectedStatus int
checkRedirect bool
checkHSTS bool
}{
{
name: "HTTP request should redirect to HTTPS",
scheme: "http",
expectedStatus: http.StatusMovedPermanently,
checkRedirect: true,
checkHSTS: false,
},
{
name: "HTTPS request should include HSTS header",
scheme: "https",
expectedStatus: http.StatusOK,
checkRedirect: false,
checkHSTS: true,
},
}
tr := http.DefaultTransport.(*http.Transport).Clone() for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := http.Client{
Transport: tr,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
server := httptest.NewTLSServer(xun.HSTS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})))
defer server.Close()
serverURL, _ := url.Parse(server.URL)
req, _ := http.NewRequest("GET", fmt.Sprintf("%s://%s", tt.scheme, serverURL.Host), nil)
resp, err := c.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, tt.expectedStatus, resp.StatusCode)
if tt.checkRedirect {
location := resp.Header.Get("Location")
require.True(t, strings.HasPrefix(location, "https://"))
}
if tt.checkHSTS {
hstsHeader := resp.Header.Get("Strict-Transport-Security")
require.NotEmpty(t, hstsHeader)
require.Contains(t, hstsHeader, "max-age=")
}
})
} You'll need to add these imports if they're not already present:
The test now uses a table-driven approach to test both HTTP and HTTPS scenarios. For the HTTPS case, it verifies:
You may need to adjust the expected status codes and header values based on your specific HSTS middleware implementation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Suggest adding tests for different HTTP methods. The current tests only cover GET requests. It would be beneficial to add tests for other HTTP methods like POST, PUT, DELETE, etc., to ensure the middleware behaves correctly in all scenarios. For non-GET/HEAD requests, the middleware should not redirect and HSTS header should not be set. Suggested implementation: func TestHstsMiddleware(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
shouldRedirect bool
shouldHaveHSTS bool
}{
{
name: "GET request",
method: http.MethodGet,
expectedStatus: http.StatusOK,
shouldRedirect: true,
shouldHaveHSTS: true,
},
{
name: "HEAD request",
method: http.MethodHead,
expectedStatus: http.StatusOK,
shouldRedirect: true,
shouldHaveHSTS: true,
},
{
name: "POST request",
method: http.MethodPost,
expectedStatus: http.StatusOK,
shouldRedirect: false,
shouldHaveHSTS: false,
},
{
name: "PUT request",
method: http.MethodPut,
expectedStatus: http.StatusOK,
shouldRedirect: false,
shouldHaveHSTS: false,
},
{
name: "DELETE request",
method: http.MethodDelete,
expectedStatus: http.StatusOK,
shouldRedirect: false,
shouldHaveHSTS: false,
},
}
tr := http.DefaultTransport.(*http.Transport).Clone() You'll also need to:
The exact implementation will depend on how the rest of the test is structured, but the test cases provide a framework for comprehensive HTTP method testing. |
||
|
||
tr := http.DefaultTransport.(*http.Transport).Clone() | ||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // skipcq: GSC-G402,GO-S1020 | ||
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { // skipcq: RVV-B0012 | ||
if strings.HasPrefix(addr, "abc.com") { | ||
return net.Dial("tcp", strings.TrimPrefix(addr, "abc.com")) | ||
} | ||
return net.Dial("tcp", addr) | ||
} | ||
|
||
c := http.Client{ | ||
Transport: tr, | ||
CheckRedirect: func(req *http.Request, via []*http.Request) error { // skipcq: RVV-B0012 | ||
return http.ErrUseLastResponse | ||
}, | ||
} | ||
|
||
t.Run("max_age_should_work", func(t *testing.T) { | ||
mux := http.NewServeMux() | ||
srv := httptest.NewServer(mux) | ||
defer srv.Close() | ||
|
||
u, err := url.Parse(srv.URL) | ||
require.NoError(t, err) | ||
|
||
l := "https://" + u.Hostname() + "/" | ||
app := xun.New(xun.WithMux(mux)) | ||
|
||
app.Use(Enable(1*time.Hour, false, false)) | ||
|
||
app.Get("/", func(c *xun.Context) error { | ||
return c.View(nil) | ||
}) | ||
|
||
req, err := http.NewRequest(http.MethodGet, srv.URL, nil) | ||
require.NoError(t, err) | ||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusFound, resp.StatusCode) | ||
require.Equal(t, l, resp.Header.Get("Location")) | ||
require.Equal(t, "max-age=3600", resp.Header.Get("Strict-Transport-Security")) // default MaxAge is 1 year | ||
}) | ||
|
||
t.Run("invalid_max_age_should_work", func(t *testing.T) { | ||
mux := http.NewServeMux() | ||
srv := httptest.NewServer(mux) | ||
defer srv.Close() | ||
|
||
u, err := url.Parse(srv.URL) | ||
require.NoError(t, err) | ||
|
||
l := "https://" + u.Hostname() + "/" | ||
app := xun.New(xun.WithMux(mux)) | ||
|
||
app.Use(Enable(0*time.Hour, false, false)) | ||
|
||
app.Get("/", func(c *xun.Context) error { | ||
return c.View(nil) | ||
}) | ||
|
||
req, err := http.NewRequest(http.MethodGet, srv.URL, nil) | ||
require.NoError(t, err) | ||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusFound, resp.StatusCode) | ||
require.Equal(t, l, resp.Header.Get("Location")) | ||
require.Equal(t, "max-age=31536000", resp.Header.Get("Strict-Transport-Security")) | ||
}) | ||
|
||
t.Run("max_age_includesubdomains_should_work", func(t *testing.T) { | ||
mux := http.NewServeMux() | ||
srv := httptest.NewServer(mux) | ||
defer srv.Close() | ||
|
||
u, err := url.Parse(srv.URL) | ||
require.NoError(t, err) | ||
|
||
l := "https://" + u.Hostname() + "/" | ||
app := xun.New(xun.WithMux(mux)) | ||
|
||
app.Use(Enable(1*time.Hour, true, false)) | ||
|
||
app.Get("/", func(c *xun.Context) error { | ||
return c.View(nil) | ||
}) | ||
|
||
req, err := http.NewRequest(http.MethodGet, srv.URL, nil) | ||
require.NoError(t, err) | ||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusFound, resp.StatusCode) | ||
require.Equal(t, l, resp.Header.Get("Location")) | ||
require.Equal(t, "max-age=3600; includeSubDomains", resp.Header.Get("Strict-Transport-Security")) | ||
}) | ||
|
||
t.Run("max_age_includesubdomains_preload_should_work", func(t *testing.T) { | ||
mux := http.NewServeMux() | ||
srv := httptest.NewServer(mux) | ||
defer srv.Close() | ||
|
||
u, err := url.Parse(srv.URL) | ||
require.NoError(t, err) | ||
|
||
l := "https://" + u.Hostname() + "/" | ||
app := xun.New(xun.WithMux(mux)) | ||
|
||
app.Use(Enable(1*time.Hour, true, true)) | ||
|
||
app.Get("/", func(c *xun.Context) error { | ||
return c.View(nil) | ||
}) | ||
|
||
req, err := http.NewRequest(http.MethodGet, srv.URL, nil) | ||
require.NoError(t, err) | ||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusFound, resp.StatusCode) | ||
require.Equal(t, l, resp.Header.Get("Location")) | ||
require.Equal(t, "max-age=3600; includeSubDomains; preload", resp.Header.Get("Strict-Transport-Security")) | ||
}) | ||
|
||
t.Run("without_port_should_work", func(t *testing.T) { | ||
mux := http.NewServeMux() | ||
|
||
srv := &http.Server{ // skipcq: GO-S2112 | ||
Addr: ":80", | ||
Handler: mux, | ||
} | ||
defer srv.Close() | ||
go srv.ListenAndServe() // nolint: errcheck | ||
|
||
l := "https://abc.com/" | ||
app := xun.New(xun.WithMux(mux)) | ||
|
||
app.Use(Enable(1*time.Hour, true, true)) | ||
|
||
app.Get("/", func(c *xun.Context) error { | ||
return c.View(nil) | ||
}) | ||
|
||
req, err := http.NewRequest(http.MethodGet, "http://abc.com/", nil) // skipcq: GO-S1028 | ||
require.NoError(t, err) | ||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusFound, resp.StatusCode) | ||
require.Equal(t, l, resp.Header.Get("Location")) | ||
require.Equal(t, "max-age=3600; includeSubDomains; preload", resp.Header.Get("Strict-Transport-Security")) | ||
}) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚨 issue (security): The X-Forwarded-Proto header should only be trusted from known proxy IPs to prevent spoofing.
Consider adding a configuration option to specify trusted proxy IPs and only accept the X-Forwarded-Proto header from these addresses.