Skip to content

Commit

Permalink
Support all HTTP methods (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
weisdd authored Apr 15, 2022
1 parent 960d7ad commit 1bac2c7
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 94 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# CHANGELOG

## 0.11.2

- Key changes:
- All HTTP methods are allowed now (previously, only POST/GET requests were supported due to technical reasons);
- `VictoriaMetrics/metricsql` bumped from `0.40.0` to `0.41.0`.

## 0.11.1

- Key changes:
Expand Down
5 changes: 5 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ tasks:
default:
- task: test
- task: lint
- task: tidy

test:
cmds:
Expand All @@ -13,3 +14,7 @@ tasks:
lint:
cmds:
- golangci-lint run

tidy:
cmds:
- go mod tidy
1 change: 1 addition & 0 deletions cmd/lfgw/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func (app *application) isNotAPIRequest(path string) bool {
// isUnsafePath returns true if the requested path targets a potentially dangerous endpoint (admin or remote write).
func (app *application) isUnsafePath(path string) bool {
// TODO: move to regexp?
// TODO: more unsafe paths?
return strings.Contains(path, "/admin/tsdb") || strings.Contains(path, "/api/v1/write")
}

Expand Down
46 changes: 17 additions & 29 deletions cmd/lfgw/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (app *application) logMiddleware(next http.Handler) http.Handler {
return
}

// Once r.ParseForm() is called, we need to update ContentLength, otherwise the request will fail
// Once r.ParseForm() is called, we need to update ContentLength, otherwise the request will fail. r.PostForm contains data for PATCH, POST, and PUT requests.
postForm := r.PostForm.Encode()
newBody := strings.NewReader(postForm)
r.ContentLength = newBody.Size()
Expand All @@ -68,27 +68,16 @@ func (app *application) logMiddleware(next http.Handler) http.Handler {
})
}

// prohibitedMethods forbids all the methods aside from "GET" and "POST".
func (app *application) prohibitedMethodsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet && r.Method != http.MethodPost {
w.Header().Set("Allow", fmt.Sprintf("%s, %s", http.MethodGet, http.MethodPost))
app.clientError(w, http.StatusMethodNotAllowed)
return
}
next.ServeHTTP(w, r)
})
}

// prohibitedPaths forbids access to some destinations that should not be proxied.
func (app *application) prohibitedPathsMiddleware(next http.Handler) http.Handler {
// safeModeMiddleware forbids access to some API endpoints if safe mode is enabled
func (app *application) safeModeMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if app.SafeMode && app.isUnsafePath(r.URL.Path) {
hlog.FromRequest(r).Error().Caller().
Msgf("Blocked a request to %s", r.URL.Path)
app.clientError(w, http.StatusForbidden)
return
}

next.ServeHTTP(w, r)
})
}
Expand All @@ -101,6 +90,7 @@ func (app *application) proxyHeadersMiddleware(next http.Handler) http.Handler {
r.Header.Set("X-Forwarded-Proto", r.URL.Scheme)
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))
}

next.ServeHTTP(w, r)
})
}
Expand Down Expand Up @@ -205,21 +195,19 @@ func (app *application) rewriteRequestMiddleware(next http.Handler) http.Handler
r.URL.RawQuery = newGetParams
app.enrichDebugLogContext(r, "new_get_params", app.unescapedURLQuery(newGetParams))

// Adjust POST params
// Partially inspired by https://github.com/bitsbeats/prometheus-acls/blob/master/internal/labeler/middleware.go
if r.Method == http.MethodPost {
newPostParams, err := app.prepareQueryParams(&r.PostForm, acl)
if err != nil {
hlog.FromRequest(r).Error().Caller().
Err(err).Msg("")
app.clientError(w, http.StatusBadRequest)
return
}
newBody := strings.NewReader(newPostParams)
r.ContentLength = newBody.Size()
r.Body = io.NopCloser(newBody)
app.enrichDebugLogContext(r, "new_post_params", app.unescapedURLQuery(newPostParams))
// For PATCH, POST, and PUT requests
newPostParams, err := app.prepareQueryParams(&r.PostForm, acl)
if err != nil {
hlog.FromRequest(r).Error().Caller().
Err(err).Msg("")
app.clientError(w, http.StatusBadRequest)
return
}
newBody := strings.NewReader(newPostParams)
r.ContentLength = newBody.Size()
r.Body = io.NopCloser(newBody)
// TODO: the field name is slightly misleading, should, probably, be renamed
app.enrichDebugLogContext(r, "new_post_params", app.unescapedURLQuery(newPostParams))

next.ServeHTTP(w, r)
})
Expand Down
82 changes: 21 additions & 61 deletions cmd/lfgw/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,96 +6,56 @@ import (
"testing"

"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)

func TestProhibitedMethodsMiddleware(t *testing.T) {
logger := zerolog.New(nil)
app := &application{
logger: &logger,
}

tests := []struct {
name string
method string
want int
}{
{
name: "GET",
method: http.MethodGet,
want: http.StatusOK,
},
{
name: "POST",
method: http.MethodGet,
want: http.StatusOK,
},
{
name: "PATCH",
method: http.MethodPatch,
want: http.StatusMethodNotAllowed,
}}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rr := httptest.NewRecorder()
r, err := http.NewRequest(tt.method, "/", nil)
if err != nil {
t.Fatal(err)
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
})
app.prohibitedMethodsMiddleware(next).ServeHTTP(rr, r)
rs := rr.Result()

if rs.StatusCode != tt.want {
t.Errorf("want %d; got %d", tt.want, rs.StatusCode)
}
defer rs.Body.Close()
})
}
}

func TestProhibitedPathsMiddleware(t *testing.T) {
func TestSafeModeMiddleware(t *testing.T) {
tests := []struct {
name string
path string
method string
safeMode bool
want int
}{
{
name: "safe mode on tsdb",
name: "tsdb (safe mode on)",
path: "/admin/tsdb",
method: http.MethodGet,
safeMode: true,
want: http.StatusForbidden,
},
{
name: "safe mode off tsdb",
name: "tsdb (safe mode off)",
path: "/admin/tsdb",
method: http.MethodGet,
safeMode: false,
want: http.StatusOK,
},
{
name: "safe mode on api write",
name: "api write (safe mode on)",
path: "/api/v1/write",
method: http.MethodGet,
safeMode: true,
want: http.StatusForbidden,
},
{
name: "safe mode off api write",
name: "api write (safe mode off)",
path: "/api/v1/write",
method: http.MethodGet,
safeMode: false,
want: http.StatusOK,
},
{
name: "safe mode on random path",
name: "random path (safe mode on)",
path: "/api/v1/test",
safeMode: false,
method: http.MethodGet,
safeMode: true,
want: http.StatusOK,
},
{
name: "safe mode off random path",
name: "random path (safe mode off)",
path: "/api/v1/test",
method: http.MethodGet,
safeMode: false,
want: http.StatusOK,
},
Expand All @@ -110,19 +70,19 @@ func TestProhibitedPathsMiddleware(t *testing.T) {
}

rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, tt.path, nil)
r, err := http.NewRequest(tt.method, tt.path, nil)
if err != nil {
t.Fatal(err)
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
})
app.prohibitedPathsMiddleware(next).ServeHTTP(rr, r)
app.safeModeMiddleware(next).ServeHTTP(rr, r)
rs := rr.Result()
got := rs.StatusCode

assert.Equal(t, tt.want, got)

if rs.StatusCode != tt.want {
t.Errorf("want %d; got %d", tt.want, rs.StatusCode)
}
defer rs.Body.Close()
})
}
Expand Down
7 changes: 3 additions & 4 deletions cmd/lfgw/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ func (app *application) routes() *mux.Router {
r.Use(app.nonProxiedEndpointsMiddleware)
r.Use(hlog.NewHandler(*app.logger))
r.Use(app.logMiddleware)
r.Use(app.prohibitedMethodsMiddleware)
r.Use(app.proxyHeadersMiddleware)
r.Use(app.oidcModeMiddleware)
// Better to keep it here to see user email in logs
r.Use(app.prohibitedPathsMiddleware)
// Better to keep it here to see user email in logs (for unsafe paths)
r.Use(app.safeModeMiddleware)
r.Use(app.proxyHeadersMiddleware)
r.Use(app.rewriteRequestMiddleware)
r.PathPrefix("/").Handler(app.proxy)
return r
Expand Down

0 comments on commit 1bac2c7

Please sign in to comment.