diff --git a/CHANGELOG.md b/CHANGELOG.md index 57bbd971035..12c70864039 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ For details about compatibility between different releases, see the **Commitment - The `http.client.transport.compression` experimental flag. It controls whether the HTTP clients used by the stack support gzip and zstd decompression of server responses. It is enabled by default. - The `http.server.transport.compression` experimental flag. It controls whether the HTTP servers used by the stack support gzip compression of the server response. It is enabled by default. - Rate limiting classes for individual HTTP paths. +- Rate limiting keys for HTTP endpoints now contain the caller API key ID when available. The caller IP is still available as a fallback. ### Changed diff --git a/pkg/ratelimit/grpc.go b/pkg/ratelimit/grpc.go index 406f2d6f9bb..9a9c9d27d29 100644 --- a/pkg/ratelimit/grpc.go +++ b/pkg/ratelimit/grpc.go @@ -18,7 +18,6 @@ import ( "context" "fmt" - "go.thethings.network/lorawan-stack/v3/pkg/auth" clusterauth "go.thethings.network/lorawan-stack/v3/pkg/auth/cluster" "go.thethings.network/lorawan-stack/v3/pkg/log" "go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata" @@ -44,17 +43,6 @@ func grpcIsClusterAuthCall(ctx context.Context) bool { return rpcmetadata.FromIncomingContext(ctx).AuthType == clusterauth.AuthType } -func grpcAuthTokenID(ctx context.Context) string { - if authValue := rpcmetadata.FromIncomingContext(ctx).AuthValue; authValue != "" { - _, id, _, err := auth.SplitToken(authValue) - if err != nil { - return "unauthenticated" - } - return id - } - return "unauthenticated" -} - // UnaryServerInterceptor returns a gRPC unary server interceptor that rate limits incoming gRPC requests. func UnaryServerInterceptor(limiter Interface) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { diff --git a/pkg/ratelimit/grpc_test.go b/pkg/ratelimit/grpc_test.go index 0e42d8295be..d66d1d7ead0 100644 --- a/pkg/ratelimit/grpc_test.go +++ b/pkg/ratelimit/grpc_test.go @@ -53,12 +53,6 @@ func (ss *serverStream) SetHeader(md metadata.MD) error { return nil } -func grpcTokenContext(authTokenID string) context.Context { - return metadata.NewIncomingContext(test.Context(), metadata.Pairs( - "authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID), - )) -} - func grpcClusterContext() context.Context { return metadata.NewIncomingContext(test.Context(), metadata.Pairs( "authorization", fmt.Sprintf("%s %X", clusterauth.AuthType, []byte{0x00, 0x01, 0x02}), @@ -82,7 +76,6 @@ func TestGRPC(t *testing.T) { const ( unaryMethod = "/Service/UnaryMethod" - authTokenID = "my-token-id" streamMethod = "/Service/StreamMethod" ) @@ -157,7 +150,7 @@ func TestGRPC(t *testing.T) { t.Run(tc.name, func(t *testing.T) { intercept := ratelimit.UnaryServerInterceptor(tc.limiter) - ctx := grpcTokenContext(authTokenID) + ctx := tokenContext(authTokenID) if tc.cluster { ctx = grpcClusterContext() } @@ -219,7 +212,7 @@ func TestGRPC(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { intercept := ratelimit.StreamServerInterceptor(tc.limiter) - ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)} + ss := &serverStream{t: t, ctx: tokenContext(authTokenID)} if tc.cluster { ss.ctx = grpcClusterContext() } @@ -238,7 +231,7 @@ func TestGRPC(t *testing.T) { "grpc:stream:up": &mockLimiter{}, } intercept := ratelimit.StreamServerInterceptor(limiter) - ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)} + ss := &serverStream{t: t, ctx: tokenContext(authTokenID)} info := &grpc.StreamServerInfo{FullMethod: streamMethod} keyFromFirstStream := "" diff --git a/pkg/ratelimit/http_test.go b/pkg/ratelimit/http_test.go index 58bca1bea35..8972b781365 100644 --- a/pkg/ratelimit/http_test.go +++ b/pkg/ratelimit/http_test.go @@ -74,6 +74,22 @@ func TestHTTP(t *testing.T) { a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{"http:test:/path/{id}", class, "http"}) }) + t.Run("AuthToken", func(t *testing.T) { + limiter.limit = false + limiter.result = ratelimit.Result{Limit: 10} + + rec := httptest.NewRecorder() + req := httpRequest("/path", "10.10.10.10").WithContext(tokenContext(authTokenID)) + handler.ServeHTTP(rec, req) + + a.So(rec.Header().Get("x-rate-limit-limit"), should.Equal, "10") + a.So(rec.Result().StatusCode, should.Equal, http.StatusOK) + + a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "/path") + a.So(limiter.calledWithResource.Key(), should.ContainSubstring, authTokenID) + a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{class, "http"}) + }) + t.Run("Limit", func(t *testing.T) { limiter.limit = true rec := httptest.NewRecorder() diff --git a/pkg/ratelimit/resource.go b/pkg/ratelimit/resource.go index 2285c141299..b72d3bcd6ba 100644 --- a/pkg/ratelimit/resource.go +++ b/pkg/ratelimit/resource.go @@ -78,8 +78,12 @@ func httpRequestResource(r *http.Request, class string) Resource { if template, ok := pathTemplate(r); ok { specificClasses = append(specificClasses, fmt.Sprintf("%s:%s", class, template)) } + authKey := fmt.Sprintf("ip:%s", httpRemoteIP(r)) + if authTokenID := authTokenID(r.Context()); authTokenID != unauthenticated { + authKey = fmt.Sprintf("token:%s", authTokenID) + } return &resource{ - key: fmt.Sprintf("%s:ip:%s:url:%s", class, httpRemoteIP(r), r.URL.Path), + key: fmt.Sprintf("%s:%s:url:%s", class, authKey, r.URL.Path), classes: append(specificClasses, class, "http"), } } @@ -87,7 +91,7 @@ func httpRequestResource(r *http.Request, class string) Resource { // grpcMethodResource represents a gRPC request. func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resource { key := fmt.Sprintf("grpc:method:%s:%s", fullMethod, grpcEntityFromRequest(ctx, req)) - if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" { + if authTokenID := authTokenID(ctx); authTokenID != unauthenticated { key = fmt.Sprintf("%s:token:%s", key, authTokenID) } return &resource{ @@ -99,7 +103,7 @@ func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resourc // grpcStreamAcceptResource represents a new gRPC server stream. func grpcStreamAcceptResource(ctx context.Context, fullMethod string) Resource { key := fmt.Sprintf("grpc:stream:accept:%s", fullMethod) - if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" { + if authTokenID := authTokenID(ctx); authTokenID != unauthenticated { key = fmt.Sprintf("%s:token:%s", key, authTokenID) } return &resource{ diff --git a/pkg/ratelimit/util_test.go b/pkg/ratelimit/util_test.go new file mode 100644 index 00000000000..7aaf76f6f4d --- /dev/null +++ b/pkg/ratelimit/util_test.go @@ -0,0 +1,31 @@ +// Copyright © 2023 The Things Network Foundation, The Things Industries B.V. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit_test + +import ( + "context" + "fmt" + + "go.thethings.network/lorawan-stack/v3/pkg/util/test" + "google.golang.org/grpc/metadata" +) + +const authTokenID = "my-token-id" + +func tokenContext(authTokenID string) context.Context { + return metadata.NewIncomingContext(test.Context(), metadata.Pairs( + "authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID), + )) +}