Skip to content

Commit

Permalink
all: Use caller API key ID in rate limiting keys
Browse files Browse the repository at this point in the history
  • Loading branch information
adriansmares committed Nov 21, 2023
1 parent a3237d3 commit e1da170
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ For details about compatibility between different releases, see the **Commitment
### Added

- Rate limiting classes for individual HTTP paths.
- Rate limiting keys for HTTP endpoints now contain the caller API key ID when available. The caller IP is still available as a fallback.

### Changed

Expand Down
12 changes: 0 additions & 12 deletions pkg/ratelimit/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"fmt"

"go.thethings.network/lorawan-stack/v3/pkg/auth"
clusterauth "go.thethings.network/lorawan-stack/v3/pkg/auth/cluster"
"go.thethings.network/lorawan-stack/v3/pkg/log"
"go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata"
Expand All @@ -44,17 +43,6 @@ func grpcIsClusterAuthCall(ctx context.Context) bool {
return rpcmetadata.FromIncomingContext(ctx).AuthType == clusterauth.AuthType
}

func grpcAuthTokenID(ctx context.Context) string {
if authValue := rpcmetadata.FromIncomingContext(ctx).AuthValue; authValue != "" {
_, id, _, err := auth.SplitToken(authValue)
if err != nil {
return "unauthenticated"
}
return id
}
return "unauthenticated"
}

// UnaryServerInterceptor returns a gRPC unary server interceptor that rate limits incoming gRPC requests.
func UnaryServerInterceptor(limiter Interface) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
Expand Down
13 changes: 3 additions & 10 deletions pkg/ratelimit/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,6 @@ func (ss *serverStream) SetHeader(md metadata.MD) error {
return nil
}

func grpcTokenContext(authTokenID string) context.Context {
return metadata.NewIncomingContext(test.Context(), metadata.Pairs(
"authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID),
))
}

func grpcClusterContext() context.Context {
return metadata.NewIncomingContext(test.Context(), metadata.Pairs(
"authorization", fmt.Sprintf("%s %X", clusterauth.AuthType, []byte{0x00, 0x01, 0x02}),
Expand All @@ -82,7 +76,6 @@ func TestGRPC(t *testing.T) {

const (
unaryMethod = "/Service/UnaryMethod"
authTokenID = "my-token-id"
streamMethod = "/Service/StreamMethod"
)

Expand Down Expand Up @@ -157,7 +150,7 @@ func TestGRPC(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
intercept := ratelimit.UnaryServerInterceptor(tc.limiter)

ctx := grpcTokenContext(authTokenID)
ctx := tokenContext(authTokenID)
if tc.cluster {
ctx = grpcClusterContext()
}
Expand Down Expand Up @@ -219,7 +212,7 @@ func TestGRPC(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
intercept := ratelimit.StreamServerInterceptor(tc.limiter)
ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)}
ss := &serverStream{t: t, ctx: tokenContext(authTokenID)}
if tc.cluster {
ss.ctx = grpcClusterContext()
}
Expand All @@ -238,7 +231,7 @@ func TestGRPC(t *testing.T) {
"grpc:stream:up": &mockLimiter{},
}
intercept := ratelimit.StreamServerInterceptor(limiter)
ss := &serverStream{t: t, ctx: grpcTokenContext(authTokenID)}
ss := &serverStream{t: t, ctx: tokenContext(authTokenID)}
info := &grpc.StreamServerInfo{FullMethod: streamMethod}

keyFromFirstStream := ""
Expand Down
16 changes: 16 additions & 0 deletions pkg/ratelimit/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,22 @@ func TestHTTP(t *testing.T) {
a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{"http:test:/path/{id}", class, "http"})
})

t.Run("AuthToken", func(t *testing.T) {
limiter.limit = false
limiter.result = ratelimit.Result{Limit: 10}

rec := httptest.NewRecorder()
req := httpRequest("/path", "10.10.10.10").WithContext(tokenContext(authTokenID))
handler.ServeHTTP(rec, req)

a.So(rec.Header().Get("x-rate-limit-limit"), should.Equal, "10")
a.So(rec.Result().StatusCode, should.Equal, http.StatusOK)

a.So(limiter.calledWithResource.Key(), should.ContainSubstring, "/path")
a.So(limiter.calledWithResource.Key(), should.ContainSubstring, authTokenID)
a.So(limiter.calledWithResource.Classes(), should.Resemble, []string{class, "http"})
})

t.Run("Limit", func(t *testing.T) {
limiter.limit = true
rec := httptest.NewRecorder()
Expand Down
10 changes: 7 additions & 3 deletions pkg/ratelimit/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,20 @@ func httpRequestResource(r *http.Request, class string) Resource {
if template, ok := pathTemplate(r); ok {
specificClasses = append(specificClasses, fmt.Sprintf("%s:%s", class, template))
}
callerInfo := fmt.Sprintf("ip:%s", httpRemoteIP(r))
if authTokenID := authTokenID(r.Context()); authTokenID != unauthenticated {
callerInfo = fmt.Sprintf("token:%s", authTokenID)
}
return &resource{
key: fmt.Sprintf("%s:ip:%s:url:%s", class, httpRemoteIP(r), r.URL.Path),
key: fmt.Sprintf("%s:%s:url:%s", class, callerInfo, r.URL.Path),
classes: append(specificClasses, class, "http"),
}
}

// grpcMethodResource represents a gRPC request.
func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resource {
key := fmt.Sprintf("grpc:method:%s:%s", fullMethod, grpcEntityFromRequest(ctx, req))
if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" {
if authTokenID := authTokenID(ctx); authTokenID != unauthenticated {
key = fmt.Sprintf("%s:token:%s", key, authTokenID)
}
return &resource{
Expand All @@ -99,7 +103,7 @@ func grpcMethodResource(ctx context.Context, fullMethod string, req any) Resourc
// grpcStreamAcceptResource represents a new gRPC server stream.
func grpcStreamAcceptResource(ctx context.Context, fullMethod string) Resource {
key := fmt.Sprintf("grpc:stream:accept:%s", fullMethod)
if authTokenID := grpcAuthTokenID(ctx); authTokenID != "" {
if authTokenID := authTokenID(ctx); authTokenID != unauthenticated {
key = fmt.Sprintf("%s:token:%s", key, authTokenID)
}
return &resource{
Expand Down
31 changes: 31 additions & 0 deletions pkg/ratelimit/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright © 2023 The Things Network Foundation, The Things Industries B.V.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ratelimit_test

import (
"context"
"fmt"

"go.thethings.network/lorawan-stack/v3/pkg/util/test"
"google.golang.org/grpc/metadata"
)

const authTokenID = "my-token-id"

func tokenContext(authTokenID string) context.Context {
return metadata.NewIncomingContext(test.Context(), metadata.Pairs(
"authorization", fmt.Sprintf("Bearer NNSXS.%s.authTokenKey", authTokenID),
))
}

0 comments on commit e1da170

Please sign in to comment.