diff --git a/CHANGELOG.md b/CHANGELOG.md index 56a1c8d91f..5f608e0278 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ For details about compatibility between different releases, see the **Commitment ### Added +- Rate limiting classes for individual HTTP paths. +- Rate limiting keys for HTTP endpoints now contain the caller API key ID when available. The caller IP is still available as a fallback. + ### Changed ### Deprecated diff --git a/pkg/basicstation/cups/server.go b/pkg/basicstation/cups/server.go index 323ac8c91f..298fd3c9d5 100644 --- a/pkg/basicstation/cups/server.go +++ b/pkg/basicstation/cups/server.go @@ -33,7 +33,6 @@ import ( "go.thethings.network/lorawan-stack/v3/pkg/web" "golang.org/x/sync/singleflight" "google.golang.org/grpc" - "google.golang.org/grpc/metadata" ) // Server implements the Basic Station Configuration and Update Server. @@ -184,17 +183,6 @@ func (s *Server) RegisterRoutes(web *web.Server) { router.Path("/update-info").HandlerFunc(s.UpdateInfo).Methods(http.MethodPost) } -func getContext(r *http.Request) context.Context { - ctx := r.Context() - md := metadata.New(map[string]string{ - "authorization": r.Header.Get("Authorization"), - }) - if ctxMd, ok := metadata.FromIncomingContext(ctx); ok { - md = metadata.Join(ctxMd, md) - } - return metadata.NewIncomingContext(ctx, md) -} - var errNoTrust = errors.DefineInternal("no_trust", "no trusted certificate found") // parseAddress parses a CUPS or LNS address. diff --git a/pkg/console/internal/events/events.go b/pkg/console/internal/events/events.go index 8cf8d77dc3..3422b160dc 100644 --- a/pkg/console/internal/events/events.go +++ b/pkg/console/internal/events/events.go @@ -62,9 +62,9 @@ func (h *eventsHandler) RegisterRoutes(server *web.Server) { router := server.APIRouter().PathPrefix(ttnpb.HTTPAPIPrefix + "/console/internal/events/").Subrouter() router.Use( mux.MiddlewareFunc(webmiddleware.Namespace("console/internal/events")), - ratelimit.HTTPMiddleware(h.component.RateLimiter(), "http:console:internal:events"), mux.MiddlewareFunc(middleware.ProtocolAuthentication(authorizationProtocolPrefix)), mux.MiddlewareFunc(webmiddleware.Metadata("Authorization")), + ratelimit.HTTPMiddleware(h.component.RateLimiter(), "http:console:internal:events"), ) router.Path("/").HandlerFunc(h.handleEvents).Methods(http.MethodGet) } diff --git a/pkg/gatewayconfigurationserver/http.go b/pkg/gatewayconfigurationserver/http.go index d5578d8c97..3912548da6 100644 --- a/pkg/gatewayconfigurationserver/http.go +++ b/pkg/gatewayconfigurationserver/http.go @@ -38,8 +38,8 @@ func (s *Server) RegisterRoutes(server *web.Server) { router := server.Prefix(ttnpb.HTTPAPIPrefix + "/gcs/gateways/{gateway_id}/").Subrouter() router.Use( mux.MiddlewareFunc(webmiddleware.Namespace("gatewayconfigurationserver")), - ratelimit.HTTPMiddleware(s.Component.RateLimiter(), "http:gcs"), mux.MiddlewareFunc(webmiddleware.Metadata("Authorization")), + ratelimit.HTTPMiddleware(s.Component.RateLimiter(), "http:gcs"), validateAndFillIDs, ) if s.config.RequireAuth { diff --git a/pkg/gatewayconfigurationserver/v2/server.go b/pkg/gatewayconfigurationserver/v2/server.go index c7b5bb66f2..87acf0f9d7 100644 --- a/pkg/gatewayconfigurationserver/v2/server.go +++ b/pkg/gatewayconfigurationserver/v2/server.go @@ -84,9 +84,9 @@ func (s *Server) RegisterRoutes(server *web.Server) { middleware := []webmiddleware.MiddlewareFunc{ webmiddleware.Namespace("gatewayconfigurationserver/v2"), - ratelimit.HTTPMiddleware(s.component.RateLimiter(), "http:gcs"), rewriteAuthorization, webmiddleware.Metadata("Authorization"), + ratelimit.HTTPMiddleware(s.component.RateLimiter(), "http:gcs"), } router.Handle( diff --git a/pkg/ratelimit/grpc.go b/pkg/ratelimit/grpc.go index 406f2d6f9b..9a9c9d27d2 100644 --- a/pkg/ratelimit/grpc.go +++ b/pkg/ratelimit/grpc.go @@ -18,7 +18,6 @@ import ( "context" "fmt" - "go.thethings.network/lorawan-stack/v3/pkg/auth" clusterauth "go.thethings.network/lorawan-stack/v3/pkg/auth/cluster" "go.thethings.network/lorawan-stack/v3/pkg/log" "go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata" @@ -44,17 +43,6 @@ func grpcIsClusterAuthCall(ctx context.Context) bool { return rpcmetadata.FromIncomingContext(ctx).AuthType == clusterauth.AuthType } -func grpcAuthTokenID(ctx context.Context) string { - if authValue := rpcmetadata.FromIncomingContext(ctx).AuthValue; authValue != "" { - _, id, _, err := auth.SplitToken(authValue) - if err != nil { - return "unauthenticated" - } - return id - } - return "unauthenticated" -} - // UnaryServerInterceptor returns a gRPC unary server interceptor that rate limits incoming gRPC requests. func UnaryServerInterceptor(limiter Interface) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { diff --git a/pkg/ratelimit/grpc_test.go b/pkg/ratelimit/grpc_test.go index 0e42d8295b..d66d1d7ead 100644 --- a/pkg/ratelimit/grpc_test.go +++ b/pkg/ratelimit/grpc_test.go @@ -53,12 +53,6 @@ func (ss *serverStream) SetHeader(md metadata.MD) error { return nil } -func grpcTokenContext(authTokenID string) context.Context { - return metadata.NewIncomingContext(test.Context(), metadata.Pairs( - "authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID), - )) -} - func grpcClusterContext() context.Context { return metadata.NewIncomingContext(test.Context(), metadata.Pairs( "authorization", fmt.Sprintf("%s %X", clusterauth.AuthType, []byte{0x00, 0x01, 0x02}), @@ -82,7 +76,6 @@ func TestGRPC(t *testing.T) { const ( unaryMethod = "/Service/UnaryMethod" - authTokenID = "my-token-id" streamMethod = "/Service/StreamMethod" ) @@ -157,7 +150,7 @@ func TestGRPC(t *testing.T) { t.Run(tc.name, func(t *testing.T) { intercept := ratelimit.UnaryServerInterceptor(tc.limiter) - ctx := grpcTokenContext(authTokenID) + ctx := tokenContext(authTokenID) if tc.cluster { ctx = grpcClusterContext() } @@ -219,7 +212,7 @@ func TestGRPC(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { intercept := ratelimit.StreamServerInterceptor(tc.limiter) - ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)} + ss := &serverStream{t: t, ctx: tokenContext(authTokenID)} if tc.cluster { ss.ctx = grpcClusterContext() } @@ -238,7 +231,7 @@ func TestGRPC(t *testing.T) { "grpc:stream:up": &mockLimiter{}, } intercept := ratelimit.StreamServerInterceptor(limiter) - ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)} + ss := &serverStream{t: t, ctx: tokenContext(authTokenID)} info := &grpc.StreamServerInfo{FullMethod: streamMethod} keyFromFirstStream := "" diff --git a/pkg/ratelimit/http_test.go b/pkg/ratelimit/http_test.go index 1ab9a6f597..8972b78136 100644 --- a/pkg/ratelimit/http_test.go +++ b/pkg/ratelimit/http_test.go @@ -54,6 +54,42 @@ func TestHTTP(t *testing.T) { a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{class, "http"}) }) + t.Run("PathTemplate", func(t *testing.T) { + limiter.limit = false + limiter.result = ratelimit.Result{Limit: 10} + + restore := ratelimit.SetPathTemplate(func(r *http.Request) (string, bool) { + return "/path/{id}", true + }) + defer restore() + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, httpRequest("/path/123", "10.10.10.10")) + + a.So(rec.Header().Get("x-rate-limit-limit"), should.Equal, "10") + a.So(rec.Result().StatusCode, should.Equal, http.StatusOK) + + a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "/path/123") + a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "10.10.10.10") + a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{"http:test:/path/{id}", class, "http"}) + }) + + t.Run("AuthToken", func(t *testing.T) { + limiter.limit = false + limiter.result = ratelimit.Result{Limit: 10} + + rec := httptest.NewRecorder() + req := httpRequest("/path", "10.10.10.10").WithContext(tokenContext(authTokenID)) + handler.ServeHTTP(rec, req) + + a.So(rec.Header().Get("x-rate-limit-limit"), should.Equal, "10") + a.So(rec.Result().StatusCode, should.Equal, http.StatusOK) + + a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "/path") + a.So(limiter.calledWithResource.Key(), should.ContainSubstring, authTokenID) + a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{class, "http"}) + }) + t.Run("Limit", func(t *testing.T) { limiter.limit = true rec := httptest.NewRecorder() diff --git a/pkg/ratelimit/resource.go b/pkg/ratelimit/resource.go index 93d6b0cc2e..dfb734d747 100644 --- a/pkg/ratelimit/resource.go +++ b/pkg/ratelimit/resource.go @@ -20,7 +20,10 @@ import ( "net" "net/http" + "github.com/gorilla/mux" + "go.thethings.network/lorawan-stack/v3/pkg/auth" "go.thethings.network/lorawan-stack/v3/pkg/events" + "go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata" "go.thethings.network/lorawan-stack/v3/pkg/ttnpb" "go.thethings.network/lorawan-stack/v3/pkg/unique" ) @@ -44,18 +47,51 @@ type resource struct { func (r *resource) Key() string { return r.key } func (r *resource) Classes() []string { return r.classes } +const unauthenticated = "unauthenticated" + +func authTokenID(ctx context.Context) string { + if authValue := rpcmetadata.FromIncomingContext(ctx).AuthValue; authValue != "" { + _, id, _, err := auth.SplitToken(authValue) + if err != nil { + return unauthenticated + } + return id + } + return unauthenticated +} + +var pathTemplate = func(r *http.Request) (string, bool) { + route := mux.CurrentRoute(r) + if route == nil { + return "", false + } + pathTemplate, err := route.GetPathTemplate() + if err != nil { + return "", false + } + return pathTemplate, true +} + // httpRequestResource represents an HTTP request. Avoid using directly, use HTTPMiddleware instead. func httpRequestResource(r *http.Request, class string) Resource { + specificClasses := make([]string, 0, 3) + if template, ok := pathTemplate(r); ok { + specificClasses = append(specificClasses, fmt.Sprintf("%s:%s", class, template)) + } + callerInfo := fmt.Sprintf("ip:%s", httpRemoteIP(r)) + if authTokenID := authTokenID(r.Context()); authTokenID != unauthenticated { + callerInfo = fmt.Sprintf("token:%s", authTokenID) + } return &resource{ - key: fmt.Sprintf("%s:ip:%s:url:%s", class, httpRemoteIP(r), r.URL.Path), - classes: []string{class, "http"}, + key: fmt.Sprintf("%s:%s:url:%s", class, callerInfo, r.URL.Path), + classes: append(specificClasses, class, "http"), } } // grpcMethodResource represents a gRPC request. func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resource { key := fmt.Sprintf("grpc:method:%s:%s", fullMethod, grpcEntityFromRequest(ctx, req)) - if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" { + if authTokenID := authTokenID(ctx); authTokenID != unauthenticated { key = fmt.Sprintf("%s:token:%s", key, authTokenID) } return &resource{ @@ -67,7 +103,7 @@ func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resourc // grpcStreamAcceptResource represents a new gRPC server stream. func grpcStreamAcceptResource(ctx context.Context, fullMethod string) Resource { key := fmt.Sprintf("grpc:stream:accept:%s", fullMethod) - if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" { + if authTokenID := authTokenID(ctx); authTokenID != unauthenticated { key = fmt.Sprintf("%s:token:%s", key, authTokenID) } return &resource{ diff --git a/pkg/ratelimit/resource_util_test.go b/pkg/ratelimit/resource_util_test.go new file mode 100644 index 0000000000..a94522a257 --- /dev/null +++ b/pkg/ratelimit/resource_util_test.go @@ -0,0 +1,24 @@ +// Copyright © 2023 The Things Network Foundation, The Things Industries B.V. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import "net/http" + +// SetPathTemplate sets the path template function for HTTP rate limiting. +func SetPathTemplate(f func(*http.Request) (string, bool)) func() { + old := pathTemplate + pathTemplate = f + return func() { pathTemplate = old } +} diff --git a/pkg/ratelimit/util_test.go b/pkg/ratelimit/util_test.go new file mode 100644 index 0000000000..7aaf76f6f4 --- /dev/null +++ b/pkg/ratelimit/util_test.go @@ -0,0 +1,31 @@ +// Copyright © 2023 The Things Network Foundation, The Things Industries B.V. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit_test + +import ( + "context" + "fmt" + + "go.thethings.network/lorawan-stack/v3/pkg/util/test" + "google.golang.org/grpc/metadata" +) + +const authTokenID = "my-token-id" + +func tokenContext(authTokenID string) context.Context { + return metadata.NewIncomingContext(test.Context(), metadata.Pairs( + "authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID), + )) +}