Skip to content

Commit

Permalink
Add content encoding with gzip, add cache headers, improve error logg…
Browse files Browse the repository at this point in the history
…ing, add html title config
  • Loading branch information
giftkugel committed Sep 10, 2024
1 parent 1beb2a6 commit 39d067c
Show file tree
Hide file tree
Showing 30 changed files with 315 additions and 175 deletions.
9 changes: 7 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ type Client struct {
type UI struct {
HideFooter bool `yaml:"hideFooter"`
HideLogo bool `yaml:"hideLogo"`
HtmlTitle string `yaml:"htmlTitle"`
Title string `yaml:"title"`
FooterText string `yaml:"footerText"`
LogoImage string `yaml:"logoImage"`
LogoContentType string `yaml:"logoContentType"`
InvalidCredentialsMessage string `yaml:"invalidCredentialsMessage"`
ExpiredLoginMessage string `yaml:"expiredLoginMessage"`
}
Expand Down Expand Up @@ -355,6 +355,11 @@ func (config *Config) GetHideLogo() bool {
return config.UI.HideLogo
}

// GetHtmlTitle returns whether the HTML title shown in the web user interface.
func (config *Config) GetHtmlTitle() string {
return config.UI.HtmlTitle
}

// GetTitle returns whether the title shown in the web user interface.
func (config *Config) GetTitle() string {
return config.UI.Title
Expand Down Expand Up @@ -390,7 +395,7 @@ func (config *Config) GetOidc() bool {

// GetIssuer returns the issuer, either by mirroring from request, from Server configuration or default value.
func (config *Config) GetIssuer(requestData *internalHttp.RequestData) string {
if requestData == nil || requestData.Host == "" || requestData.Scheme == "" {
if requestData == nil || !requestData.Valid() {
return GetOrDefaultString(config.Server.Issuer, "STOPnik")
}
return GetOrDefaultString(config.Server.Issuer, requestData.IssuerString())
Expand Down
4 changes: 4 additions & 0 deletions internal/http/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package http
const (
Location string = "Location"
ContentType string = "Content-Type"
ContentEncoding string = "Content-Encoding"
CacheControl string = "Cache-Control"
ETag string = "ETag"
Authorization string = "Authorization"
AccessControlAllowOrigin string = "Access-Control-Allow-Origin"
AcceptEncoding string = "Accept-Encoding"
XForwardProtocol string = "X-Forwarded-Proto"
XForwardHost string = "X-Forwarded-Host"
XForwardUri string = "X-Forwarded-Uri"
Expand Down
15 changes: 11 additions & 4 deletions internal/http/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@ import (
"net/http"
)

func SendJson(value any, w http.ResponseWriter) error {
return SendJsonWithStatus(value, w, http.StatusOK)
func SendJson(value any, w http.ResponseWriter, r *http.Request) error {
return SendJsonWithStatus(value, http.StatusOK, w, r)
}

func SendJsonWithStatus(value any, w http.ResponseWriter, statusCode int) error {
func SendJsonWithStatus(value any, statusCode int, w http.ResponseWriter, r *http.Request) error {
bytes, tokenMarshalError := json.Marshal(value)
if tokenMarshalError != nil {
return tokenMarshalError
}

requestData := NewRequestData(r)
responseWriter := NewResponseWriter(w, requestData)

responseWriter.SetEncodingHeader()

w.Header().Set(ContentType, ContentTypeJSON)
w.Header().Set(CacheControl, "private, no-store")
w.Header().Set(AccessControlAllowOrigin, "*")
w.WriteHeader(statusCode)
_, writeError := w.Write(bytes)

_, writeError := responseWriter.Write(bytes)
if writeError != nil {
return writeError
}
Expand Down
4 changes: 3 additions & 1 deletion internal/http/json_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"net/http"
"net/http/httptest"
"testing"
)
Expand All @@ -18,7 +19,8 @@ func Test_SendJson(t *testing.T) {
Age: 20,
}

err := SendJson(data, rr)
request := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
err := SendJson(data, rr, request)

if err != nil {
t.Error(err)
Expand Down
57 changes: 45 additions & 12 deletions internal/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
)

type CompressionMethod string

const (
CompressionMethodGZip CompressionMethod = "gzip"
)

var supportedEncodingMethods = []CompressionMethod{CompressionMethodGZip}

type RequestData struct {
Scheme string
Host string
Path string
Query string
Fragment string
scheme string
host string
path string
query string
fragment string
compressed *CompressionMethod
}

func NewRequestData(r *http.Request) *RequestData {
Expand All @@ -31,21 +41,44 @@ func NewRequestData(r *http.Request) *RequestData {
if r.URL.RawFragment != "" {
fragment = "#" + r.URL.RawFragment
}

acceptEncodingHeader := r.Header.Get(AcceptEncoding)
acceptEncoding := strings.Split(acceptEncodingHeader, ", ")

var compress CompressionMethod
for _, encoding := range acceptEncoding {
for _, supported := range supportedEncodingMethods {
if encoding == string(supported) {
compress = supported
break
}
}
}

return &RequestData{
Scheme: scheme,
Host: host,
Path: path,
Query: query,
Fragment: fragment,
scheme: scheme,
host: host,
path: path,
query: query,
fragment: fragment,
compressed: &compress,
}
}

func (r *RequestData) IssuerString() string {
return fmt.Sprintf("%s://%s", r.Scheme, r.Host)
return fmt.Sprintf("%s://%s", r.scheme, r.host)
}

func (r *RequestData) URL() (*url.URL, error) {
uri := fmt.Sprintf("%s://%s%s%s%s", r.Scheme, r.Host, r.Path, r.Query, r.Fragment)
uri := fmt.Sprintf("%s://%s%s%s%s", r.scheme, r.host, r.path, r.query, r.fragment)

return url.Parse(uri)
}

func (r *RequestData) Valid() bool {
return r.host != "" && r.scheme != ""
}

func (r *RequestData) AcceptCompressed() (*CompressionMethod, bool) {
return r.compressed, r.compressed != nil && *r.compressed != ""
}
18 changes: 9 additions & 9 deletions internal/http/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,24 @@ func Test_RequestData(t *testing.T) {

requestData := NewRequestData(request)

if requestData.Scheme != test.expectedScheme {
t.Errorf("Scheme mismatch. Expected: %s, got: %s", test.expectedScheme, requestData.Scheme)
if requestData.scheme != test.expectedScheme {
t.Errorf("Scheme mismatch. Expected: %s, got: %s", test.expectedScheme, requestData.scheme)
}

if requestData.Host != test.expectedHost {
if requestData.host != test.expectedHost {
t.Errorf("Host mismatch. Expected: %s, got: %s", test.expectedHost, request.Host)
}

if requestData.Path != test.expectedPath {
t.Errorf("Path mismatch. Expected: %s, got: %s", test.expectedPath, requestData.Path)
if requestData.path != test.expectedPath {
t.Errorf("Path mismatch. Expected: %s, got: %s", test.expectedPath, requestData.path)
}

if requestData.Query != test.expectedQuery {
t.Errorf("Query mismatch. Expected: %s, got: %s", test.expectedQuery, requestData.Query)
if requestData.query != test.expectedQuery {
t.Errorf("Query mismatch. Expected: %s, got: %s", test.expectedQuery, requestData.query)
}

if requestData.Fragment != test.expectedFragment {
t.Errorf("Fragment mismatch. Expected: %s, got: %s", test.expectedFragment, requestData.Fragment)
if requestData.fragment != test.expectedFragment {
t.Errorf("Fragment mismatch. Expected: %s, got: %s", test.expectedFragment, requestData.fragment)
}

parsedUrl, parseError := requestData.URL()
Expand Down
53 changes: 53 additions & 0 deletions internal/http/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package http

import (
"compress/gzip"
"github.com/webishdev/stopnik/internal/system"
"net/http"
)

type ResponseWriter struct {
requestData *RequestData
w http.ResponseWriter
headerWritten bool
}

func NewResponseWriter(w http.ResponseWriter, requestData *RequestData) *ResponseWriter {
return &ResponseWriter{
requestData: requestData,
w: w,
headerWritten: false,
}
}

func (rw *ResponseWriter) SetEncodingHeader() {
compressionMethod, acceptCompressed := rw.requestData.AcceptCompressed()
if acceptCompressed {
switch *compressionMethod {
case CompressionMethodGZip:
rw.w.Header().Set(ContentEncoding, string(CompressionMethodGZip))
rw.headerWritten = true
return
}

}
}

func (rw *ResponseWriter) Write(p []byte) (int, error) {
compressionMethod, acceptCompressed := rw.requestData.AcceptCompressed()
if rw.headerWritten && acceptCompressed {
switch *compressionMethod {
case CompressionMethodGZip:
gw := gzip.NewWriter(rw.w)
defer func(gw *gzip.Writer) {
err := gw.Close()
if err != nil {
system.Error(err)
}
}(gw)
return gw.Write(p)
}

}
return rw.w.Write(p)
}
8 changes: 4 additions & 4 deletions internal/oauth2/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ func AuthorizationErrorResponseHandler(w http.ResponseWriter, redirectURL *url.U
w.WriteHeader(http.StatusFound)
}

func TokenErrorResponseHandler(w http.ResponseWriter, errorResponseParameter *TokenErrorResponseParameter) {
TokenErrorStatusResponseHandler(w, http.StatusBadRequest, errorResponseParameter)
func TokenErrorResponseHandler(w http.ResponseWriter, r *http.Request, errorResponseParameter *TokenErrorResponseParameter) {
TokenErrorStatusResponseHandler(w, r, http.StatusBadRequest, errorResponseParameter)
}

func TokenErrorStatusResponseHandler(w http.ResponseWriter, statusCode int, errorResponseParameter *TokenErrorResponseParameter) {
err := internalHttp.SendJsonWithStatus(errorResponseParameter, w, statusCode)
func TokenErrorStatusResponseHandler(w http.ResponseWriter, r *http.Request, statusCode int, errorResponseParameter *TokenErrorResponseParameter) {
err := internalHttp.SendJsonWithStatus(errorResponseParameter, statusCode, w, r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
Expand Down
29 changes: 10 additions & 19 deletions internal/oauth2/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,7 @@ import (
"testing"
)

func Test_Error(t *testing.T) {

testAuthorizationErrorResponseHandler(t)

testTokenErrorResponseHandler(t)

testTokenErrorStatusResponseHandler(t)

testAuthorizationErrorTypeFromString(t)

testTokenErrorTypeFromString(t)

func Test_ErrorNoRedirectUri(t *testing.T) {
t.Run("No redirect uri provided", func(t *testing.T) {
rr := httptest.NewRecorder()

Expand All @@ -33,7 +22,7 @@ func Test_Error(t *testing.T) {
})
}

func testAuthorizationErrorTypeFromString(t *testing.T) {
func Test_AuthorizationErrorTypeFromString(t *testing.T) {
type authorizationErrorTypeParameter struct {
value string
exists bool
Expand Down Expand Up @@ -61,7 +50,7 @@ func testAuthorizationErrorTypeFromString(t *testing.T) {
}
}

func testAuthorizationErrorResponseHandler(t *testing.T) {
func Test_AuthorizationErrorResponseHandler(t *testing.T) {
type errorResponseHandlerParameter struct {
state string
expectedErrorParameter AuthorizationErrorType
Expand Down Expand Up @@ -140,24 +129,26 @@ func testAuthorizationErrorResponseHandler(t *testing.T) {
}
}

func testTokenErrorResponseHandler(t *testing.T) {
func Test_TokenErrorResponseHandler(t *testing.T) {
t.Run("Token error handler", func(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
rr := httptest.NewRecorder()
TokenErrorResponseHandler(rr, &TokenErrorResponseParameter{})
TokenErrorResponseHandler(rr, request, &TokenErrorResponseParameter{})

if rr.Code != http.StatusBadRequest {
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest)
}
})
}

func testTokenErrorStatusResponseHandler(t *testing.T) {
func Test_TokenErrorStatusResponseHandler(t *testing.T) {
statusCodes := []int{http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusRequestTimeout}
for _, statusCode := range statusCodes {
testMessage := fmt.Sprintf("Token error handler with status code %d", statusCode)
t.Run(testMessage, func(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
rr := httptest.NewRecorder()
TokenErrorStatusResponseHandler(rr, statusCode, &TokenErrorResponseParameter{})
TokenErrorStatusResponseHandler(rr, request, statusCode, &TokenErrorResponseParameter{})

if rr.Code != statusCode {
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, statusCode)
Expand All @@ -166,7 +157,7 @@ func testTokenErrorStatusResponseHandler(t *testing.T) {
}
}

func testTokenErrorTypeFromString(t *testing.T) {
func Test_TokenErrorTypeFromString(t *testing.T) {
type tokenErrorTypeParameter struct {
value string
exists bool
Expand Down
Loading

0 comments on commit 39d067c

Please sign in to comment.