Skip to content

Commit

Permalink
Merge pull request #6714 from TheThingsNetwork/feature/rate-limit-http
Browse files Browse the repository at this point in the history
Improve HTTP rate limiting classes and keys
  • Loading branch information
adriansmares authored Nov 21, 2023
2 parents a540687 + e1da170 commit 69259b5
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 41 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ For details about compatibility between different releases, see the **Commitment

### Added

- Rate limiting classes for individual HTTP paths.
- Rate limiting keys for HTTP endpoints now contain the caller API key ID when available. The caller IP is still available as a fallback.

### Changed

- Server side events replaced with single socket connection using the native WebSocket API.
Expand Down
12 changes: 0 additions & 12 deletions pkg/basicstation/cups/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"go.thethings.network/lorawan-stack/v3/pkg/web"
"golang.org/x/sync/singleflight"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

// Server implements the Basic Station Configuration and Update Server.
Expand Down Expand Up @@ -184,17 +183,6 @@ func (s *Server) RegisterRoutes(web *web.Server) {
router.Path("/update-info").HandlerFunc(s.UpdateInfo).Methods(http.MethodPost)
}

func getContext(r *http.Request) context.Context {
ctx := r.Context()
md := metadata.New(map[string]string{
"authorization": r.Header.Get("Authorization"),
})
if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
return metadata.NewIncomingContext(ctx, md)
}

var errNoTrust = errors.DefineInternal("no_trust", "no trusted certificate found")

// parseAddress parses a CUPS or LNS address.
Expand Down
2 changes: 1 addition & 1 deletion pkg/console/internal/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ func (h *eventsHandler) RegisterRoutes(server *web.Server) {
router := server.APIRouter().PathPrefix(ttnpb.HTTPAPIPrefix + "/console/internal/events/").Subrouter()
router.Use(
mux.MiddlewareFunc(webmiddleware.Namespace("console/internal/events")),
ratelimit.HTTPMiddleware(h.component.RateLimiter(), "http:console:internal:events"),
mux.MiddlewareFunc(middleware.ProtocolAuthentication(authorizationProtocolPrefix)),
mux.MiddlewareFunc(webmiddleware.Metadata("Authorization")),
ratelimit.HTTPMiddleware(h.component.RateLimiter(), "http:console:internal:events"),
)
router.Path("/").HandlerFunc(h.handleEvents).Methods(http.MethodGet)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/gatewayconfigurationserver/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func (s *Server) RegisterRoutes(server *web.Server) {
router := server.Prefix(ttnpb.HTTPAPIPrefix + "/gcs/gateways/{gateway_id}/").Subrouter()
router.Use(
mux.MiddlewareFunc(webmiddleware.Namespace("gatewayconfigurationserver")),
ratelimit.HTTPMiddleware(s.Component.RateLimiter(), "http:gcs"),
mux.MiddlewareFunc(webmiddleware.Metadata("Authorization")),
ratelimit.HTTPMiddleware(s.Component.RateLimiter(), "http:gcs"),
validateAndFillIDs,
)
if s.config.RequireAuth {
Expand Down
2 changes: 1 addition & 1 deletion pkg/gatewayconfigurationserver/v2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ func (s *Server) RegisterRoutes(server *web.Server) {

middleware := []webmiddleware.MiddlewareFunc{
webmiddleware.Namespace("gatewayconfigurationserver/v2"),
ratelimit.HTTPMiddleware(s.component.RateLimiter(), "http:gcs"),
rewriteAuthorization,
webmiddleware.Metadata("Authorization"),
ratelimit.HTTPMiddleware(s.component.RateLimiter(), "http:gcs"),
}

router.Handle(
Expand Down
12 changes: 0 additions & 12 deletions pkg/ratelimit/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"fmt"

"go.thethings.network/lorawan-stack/v3/pkg/auth"
clusterauth "go.thethings.network/lorawan-stack/v3/pkg/auth/cluster"
"go.thethings.network/lorawan-stack/v3/pkg/log"
"go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata"
Expand All @@ -44,17 +43,6 @@ func grpcIsClusterAuthCall(ctx context.Context) bool {
return rpcmetadata.FromIncomingContext(ctx).AuthType == clusterauth.AuthType
}

func grpcAuthTokenID(ctx context.Context) string {
if authValue := rpcmetadata.FromIncomingContext(ctx).AuthValue; authValue != "" {
_, id, _, err := auth.SplitToken(authValue)
if err != nil {
return "unauthenticated"
}
return id
}
return "unauthenticated"
}

// UnaryServerInterceptor returns a gRPC unary server interceptor that rate limits incoming gRPC requests.
func UnaryServerInterceptor(limiter Interface) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
Expand Down
13 changes: 3 additions & 10 deletions pkg/ratelimit/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,6 @@ func (ss *serverStream) SetHeader(md metadata.MD) error {
return nil
}

func grpcTokenContext(authTokenID string) context.Context {
return metadata.NewIncomingContext(test.Context(), metadata.Pairs(
"authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID),
))
}

func grpcClusterContext() context.Context {
return metadata.NewIncomingContext(test.Context(), metadata.Pairs(
"authorization", fmt.Sprintf("%s %X", clusterauth.AuthType, []byte{0x00, 0x01, 0x02}),
Expand All @@ -82,7 +76,6 @@ func TestGRPC(t *testing.T) {

const (
unaryMethod = "/Service/UnaryMethod"
authTokenID = "my-token-id"
streamMethod = "/Service/StreamMethod"
)

Expand Down Expand Up @@ -157,7 +150,7 @@ func TestGRPC(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
intercept := ratelimit.UnaryServerInterceptor(tc.limiter)

ctx := grpcTokenContext(authTokenID)
ctx := tokenContext(authTokenID)
if tc.cluster {
ctx = grpcClusterContext()
}
Expand Down Expand Up @@ -219,7 +212,7 @@ func TestGRPC(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
intercept := ratelimit.StreamServerInterceptor(tc.limiter)
ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)}
ss := &serverStream{t: t, ctx: tokenContext(authTokenID)}
if tc.cluster {
ss.ctx = grpcClusterContext()
}
Expand All @@ -238,7 +231,7 @@ func TestGRPC(t *testing.T) {
"grpc:stream:up": &mockLimiter{},
}
intercept := ratelimit.StreamServerInterceptor(limiter)
ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)}
ss := &serverStream{t: t, ctx: tokenContext(authTokenID)}
info := &grpc.StreamServerInfo{FullMethod: streamMethod}

keyFromFirstStream := ""
Expand Down
36 changes: 36 additions & 0 deletions pkg/ratelimit/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,42 @@ func TestHTTP(t *testing.T) {
a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{class, "http"})
})

t.Run("PathTemplate", func(t *testing.T) {
limiter.limit = false
limiter.result = ratelimit.Result{Limit: 10}

restore := ratelimit.SetPathTemplate(func(r *http.Request) (string, bool) {
return "/path/{id}", true
})
defer restore()

rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httpRequest("/path/123", "10.10.10.10"))

a.So(rec.Header().Get("x-rate-limit-limit"), should.Equal, "10")
a.So(rec.Result().StatusCode, should.Equal, http.StatusOK)

a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "/path/123")
a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "10.10.10.10")
a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{"http:test:/path/{id}", class, "http"})
})

t.Run("AuthToken", func(t *testing.T) {
limiter.limit = false
limiter.result = ratelimit.Result{Limit: 10}

rec := httptest.NewRecorder()
req := httpRequest("/path", "10.10.10.10").WithContext(tokenContext(authTokenID))
handler.ServeHTTP(rec, req)

a.So(rec.Header().Get("x-rate-limit-limit"), should.Equal, "10")
a.So(rec.Result().StatusCode, should.Equal, http.StatusOK)

a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "/path")
a.So(limiter.calledWithResource.Key(), should.ContainSubstring, authTokenID)
a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{class, "http"})
})

t.Run("Limit", func(t *testing.T) {
limiter.limit = true
rec := httptest.NewRecorder()
Expand Down
44 changes: 40 additions & 4 deletions pkg/ratelimit/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ import (
"net"
"net/http"

"github.com/gorilla/mux"
"go.thethings.network/lorawan-stack/v3/pkg/auth"
"go.thethings.network/lorawan-stack/v3/pkg/events"
"go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata"
"go.thethings.network/lorawan-stack/v3/pkg/ttnpb"
"go.thethings.network/lorawan-stack/v3/pkg/unique"
)
Expand All @@ -44,18 +47,51 @@ type resource struct {
func (r *resource) Key() string { return r.key }
func (r *resource) Classes() []string { return r.classes }

const unauthenticated = "unauthenticated"

func authTokenID(ctx context.Context) string {
if authValue := rpcmetadata.FromIncomingContext(ctx).AuthValue; authValue != "" {
_, id, _, err := auth.SplitToken(authValue)
if err != nil {
return unauthenticated
}
return id
}
return unauthenticated
}

var pathTemplate = func(r *http.Request) (string, bool) {
route := mux.CurrentRoute(r)
if route == nil {
return "", false
}
pathTemplate, err := route.GetPathTemplate()
if err != nil {
return "", false
}
return pathTemplate, true
}

// httpRequestResource represents an HTTP request. Avoid using directly, use HTTPMiddleware instead.
func httpRequestResource(r *http.Request, class string) Resource {
specificClasses := make([]string, 0, 3)
if template, ok := pathTemplate(r); ok {
specificClasses = append(specificClasses, fmt.Sprintf("%s:%s", class, template))
}
callerInfo := fmt.Sprintf("ip:%s", httpRemoteIP(r))
if authTokenID := authTokenID(r.Context()); authTokenID != unauthenticated {
callerInfo = fmt.Sprintf("token:%s", authTokenID)
}
return &resource{
key: fmt.Sprintf("%s:ip:%s:url:%s", class, httpRemoteIP(r), r.URL.Path),
classes: []string{class, "http"},
key: fmt.Sprintf("%s:%s:url:%s", class, callerInfo, r.URL.Path),
classes: append(specificClasses, class, "http"),
}
}

// grpcMethodResource represents a gRPC request.
func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resource {
key := fmt.Sprintf("grpc:method:%s:%s", fullMethod, grpcEntityFromRequest(ctx, req))
if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" {
if authTokenID := authTokenID(ctx); authTokenID != unauthenticated {
key = fmt.Sprintf("%s:token:%s", key, authTokenID)
}
return &resource{
Expand All @@ -67,7 +103,7 @@ func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resourc
// grpcStreamAcceptResource represents a new gRPC server stream.
func grpcStreamAcceptResource(ctx context.Context, fullMethod string) Resource {
key := fmt.Sprintf("grpc:stream:accept:%s", fullMethod)
if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" {
if authTokenID := authTokenID(ctx); authTokenID != unauthenticated {
key = fmt.Sprintf("%s:token:%s", key, authTokenID)
}
return &resource{
Expand Down
24 changes: 24 additions & 0 deletions pkg/ratelimit/resource_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright © 2023 The Things Network Foundation, The Things Industries B.V.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ratelimit

import "net/http"

// SetPathTemplate sets the path template function for HTTP rate limiting.
func SetPathTemplate(f func(*http.Request) (string, bool)) func() {
old := pathTemplate
pathTemplate = f
return func() { pathTemplate = old }
}
31 changes: 31 additions & 0 deletions pkg/ratelimit/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright © 2023 The Things Network Foundation, The Things Industries B.V.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ratelimit_test

import (
"context"
"fmt"

"go.thethings.network/lorawan-stack/v3/pkg/util/test"
"google.golang.org/grpc/metadata"
)

const authTokenID = "my-token-id"

func tokenContext(authTokenID string) context.Context {
return metadata.NewIncomingContext(test.Context(), metadata.Pairs(
"authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID),
))
}

0 comments on commit 69259b5

Please sign in to comment.