From d577ec8085a492f9cf4d0d08f1d5c830fde9b107 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Mon, 21 Jun 2021 21:29:27 +0800 Subject: [PATCH] beater: even more refactoring (#5502) * beater: even more refactoring - rate limiting middleware is now installed for both RUM and backend agent APIs, but only applies for anonymous clients (currently only RUM) - rate limiting middleware now performs an initial Allow check at the request level, for consistent request rate limiting of those endpoints that are rate limited - agent config now restricts "insecure" (RUM) agents on the basis that they are anonymous, rather than being RUM specifically. The list of insecure agent names (those allowed for anonymous auth) is now passed in * make gofmt * beater/api/profile: remove unused field --- beater/api/config/agent/handler.go | 63 +++++------ beater/api/config/agent/handler_test.go | 89 +++++++-------- beater/api/intake/handler.go | 33 ++---- beater/api/intake/handler_test.go | 14 ++- beater/api/mux.go | 86 +++++++++----- beater/api/mux_intake_rum_test.go | 4 +- beater/api/mux_test.go | 7 +- beater/api/profile/handler.go | 20 ++-- beater/api/profile/handler_test.go | 17 +-- .../ratelimit/context.go} | 30 +++-- beater/api/ratelimit/store.go | 17 +-- beater/api/ratelimit/store_test.go | 29 +---- beater/http.go | 14 ++- beater/middleware/authorization_middleware.go | 1 + beater/middleware/rate_limit_middleware.go | 31 +++-- .../middleware/rate_limit_middleware_test.go | 107 ++++++++++++++++++ .../middleware/request_metadata_middleware.go | 46 -------- .../request_metadata_middleware_test.go | 77 ------------- beater/request/context.go | 36 ++---- beater/request/context_test.go | 11 +- beater/server.go | 28 +++-- beater/tracing.go | 14 ++- systemtest/rum_test.go | 7 +- 23 files changed, 403 insertions(+), 378 deletions(-) rename beater/{middleware/rum_middleware.go => api/ratelimit/context.go} (52%) create mode 100644 beater/middleware/rate_limit_middleware_test.go delete mode 100644 beater/middleware/request_metadata_middleware.go delete mode 100644 beater/middleware/request_metadata_middleware_test.go diff --git a/beater/api/config/agent/handler.go b/beater/api/config/agent/handler.go index 04086dcc9d2..f77a39f24fa 100644 --- a/beater/api/config/agent/handler.go +++ b/beater/api/config/agent/handler.go @@ -52,18 +52,21 @@ var ( registry = monitoring.Default.NewRegistry("apm-server.acm") errCacheControl = fmt.Sprintf("max-age=%v, must-revalidate", errMaxAgeDuration.Seconds()) - - // rumAgents keywords (new and old) - rumAgents = []string{"rum-js", "js-base"} ) type handler struct { f agentcfg.Fetcher + allowAnonymousAgents []string cacheControl, defaultServiceEnvironment string } -func NewHandler(f agentcfg.Fetcher, config config.KibanaAgentConfig, defaultServiceEnvironment string) request.Handler { +func NewHandler( + f agentcfg.Fetcher, + config config.KibanaAgentConfig, + defaultServiceEnvironment string, + allowAnonymousAgents []string, +) request.Handler { if f == nil { panic("fetcher must not be nil") } @@ -72,6 +75,7 @@ func NewHandler(f agentcfg.Fetcher, config config.KibanaAgentConfig, defaultServ f: f, cacheControl: cacheControl, defaultServiceEnvironment: defaultServiceEnvironment, + allowAnonymousAgents: allowAnonymousAgents, } return h.Handle @@ -83,13 +87,6 @@ func (h *handler) Handle(c *request.Context) { // error handling c.Header().Set(headers.CacheControl, errCacheControl) - ok := c.RateLimiter == nil || c.RateLimiter.Allow() - if !ok { - c.Result.SetDefault(request.IDResponseErrorsRateLimit) - c.Write() - return - } - query, queryErr := buildQuery(c) if queryErr != nil { extractQueryError(c, queryErr) @@ -100,26 +97,29 @@ func (h *handler) Handle(c *request.Context) { query.Service.Environment = h.defaultServiceEnvironment } - if !c.AuthResult.Anonymous { - // The exact agent is not always known for anonymous clients, so we do not - // issue a secondary authorization check for them. Instead, we issue the - // request and filter the results using query.InsecureAgents. - authResource := authorization.Resource{ServiceName: query.Service.Name} - if result, err := authorization.AuthorizedFor(c.Request.Context(), authResource); err != nil { - c.Result.SetDefault(request.IDResponseErrorsServiceUnavailable) - c.Result.Err = err - c.Write() - return - } else if !result.Authorized { - id := request.IDResponseErrorsUnauthorized - status := request.MapResultIDToStatus[id] - if result.Reason != "" { - status.Keyword = result.Reason - } - c.Result.Set(id, status.Code, status.Keyword, nil, nil) - c.Write() - return + // Only service, and not agent, is known for config queries. + // For anonymous/untrusted agents, we filter the results using + // query.InsecureAgents below. + authResource := authorization.Resource{ServiceName: query.Service.Name} + authResult, err := authorization.AuthorizedFor(c.Request.Context(), authResource) + if err != nil { + c.Result.SetDefault(request.IDResponseErrorsServiceUnavailable) + c.Result.Err = err + c.Write() + return + } + if !authResult.Authorized { + id := request.IDResponseErrorsUnauthorized + status := request.MapResultIDToStatus[id] + if authResult.Reason != "" { + status.Keyword = authResult.Reason } + c.Result.Set(id, status.Code, status.Keyword, nil, nil) + c.Write() + return + } + if authResult.Anonymous { + query.InsecureAgents = h.allowAnonymousAgents } result, err := h.f.Fetch(c.Request.Context(), query) @@ -184,9 +184,6 @@ func buildQuery(c *request.Context) (agentcfg.Query, error) { return query, errors.New(agentcfg.ServiceName + " is required") } - if c.IsRum { - query.InsecureAgents = rumAgents - } query.Etag = ifNoneMatch(c) return query, nil } diff --git a/beater/api/config/agent/handler_test.go b/beater/api/config/agent/handler_test.go index 7dcbc1dc079..7ab065f84d3 100644 --- a/beater/api/config/agent/handler_test.go +++ b/beater/api/config/agent/handler_test.go @@ -32,7 +32,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.elastic.co/apm/apmtest" - "golang.org/x/time/rate" "github.com/elastic/beats/v7/libbeat/common" libkibana "github.com/elastic/beats/v7/libbeat/kibana" @@ -173,18 +172,12 @@ func TestAgentConfigHandler(t *testing.T) { var cfg = config.KibanaAgentConfig{Cache: config.Cache{Expiration: 4 * time.Second}} for _, tc := range testcases { f := agentcfg.NewKibanaFetcher(tc.kbClient, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) r := httptest.NewRequest(tc.method, target(tc.queryParams), nil) for k, v := range tc.requestHeader { r.Header.Set(k, v) } ctx, w := newRequestContext(r) - ctx.AuthResult.Authorized = true - ctx.Request = withAuthorization(ctx.Request, - authorizedForFunc(func(context.Context, authorization.Resource) (authorization.Result, error) { - return authorization.Result{Authorized: true}, nil - }), - ) h(ctx) require.Equal(t, tc.respStatus, w.Code) @@ -202,32 +195,46 @@ func TestAgentConfigHandlerAnonymousAccess(t *testing.T) { kbClient := kibanatest.MockKibana(http.StatusUnauthorized, m{"error": "Unauthorized"}, mockVersion, true) cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kbClient, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) for _, tc := range []struct { - anonymous bool - response string + anonymous bool + response string + authResource *authorization.Resource }{{ - anonymous: false, - response: `{"error":"APM Server is not authorized to query Kibana. Please configure apm-server.kibana.username and apm-server.kibana.password, and ensure the user has the necessary privileges."}`, + anonymous: false, + response: `{"error":"APM Server is not authorized to query Kibana. Please configure apm-server.kibana.username and apm-server.kibana.password, and ensure the user has the necessary privileges."}`, + authResource: &authorization.Resource{ServiceName: "opbeans"}, }, { - anonymous: true, - response: `{"error":"Unauthorized"}`, + anonymous: true, + response: `{"error":"Unauthorized"}`, + authResource: &authorization.Resource{ServiceName: "opbeans"}, }} { r := httptest.NewRequest(http.MethodGet, target(map[string]string{"service.name": "opbeans"}), nil) - ctx, w := newRequestContext(r) - ctx.AuthResult.Authorized = true - ctx.AuthResult.Anonymous = tc.anonymous - ctx.Request = withAuthorization(ctx.Request, authorization.AnonymousAuth{}) - h(ctx) + c, w := newRequestContext(r) + c.AuthResult.Authorized = true + c.AuthResult.Anonymous = tc.anonymous + + var requestedResource *authorization.Resource + c.Request = withAuthorization(c.Request, + authorizedForFunc(func(ctx context.Context, resource authorization.Resource) (authorization.Result, error) { + if requestedResource != nil { + panic("expected only one AuthorizedFor request") + } + requestedResource = &resource + return c.AuthResult, nil + }), + ) + h(c) assert.Equal(t, tc.response+"\n", w.Body.String()) + assert.Equal(t, tc.authResource, requestedResource) } } func TestAgentConfigHandlerAuthorizedForService(t *testing.T) { cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(nil, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) r := httptest.NewRequest(http.MethodGet, target(map[string]string{"service.name": "opbeans"}), nil) ctx, w := newRequestContext(r) @@ -249,7 +256,7 @@ func TestAgentConfigHandlerAuthorizedForService(t *testing.T) { func TestAgentConfigHandler_NoKibanaClient(t *testing.T) { cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(nil, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) w := sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{ "service": m{"name": "opbeans-node"}}))) @@ -268,7 +275,7 @@ func TestAgentConfigHandler_PostOk(t *testing.T) { var cfg = config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kb, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) w := sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{ "service": m{"name": "opbeans-node"}}))) @@ -289,7 +296,7 @@ func TestAgentConfigHandler_DefaultServiceEnvironment(t *testing.T) { var cfg = config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kb, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "default") + h := NewHandler(f, cfg, "default", nil) sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{"service": m{"name": "opbeans-node", "environment": "specified"}}))) sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{"service": m{"name": "opbeans-node"}}))) @@ -306,8 +313,6 @@ func TestAgentConfigRum(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/rum", convert.ToReader(m{ "service": m{"name": "opbeans"}})) ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.AuthResult.Anonymous = true h(ctx) var actual map[string]string json.Unmarshal(w.Body.Bytes(), &actual) @@ -320,8 +325,6 @@ func TestAgentConfigRumEtag(t *testing.T) { h := getHandler("rum-js") r := httptest.NewRequest(http.MethodGet, "/rum?ifnonematch=123&service.name=opbeans", nil) ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.AuthResult.Anonymous = true h(ctx) assert.Equal(t, http.StatusNotModified, w.Code, w.Body.String()) } @@ -333,7 +336,7 @@ func TestAgentConfigNotRum(t *testing.T) { ctx, w := newRequestContext(r) ctx.Request = withAuthorization(ctx.Request, authorizedForFunc(func(context.Context, authorization.Resource) (authorization.Result, error) { - return authorization.Result{Authorized: true}, nil + return authorization.Result{Authorized: true, Anonymous: false}, nil }), ) h(ctx) @@ -348,8 +351,6 @@ func TestAgentConfigNoLeak(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/rum", convert.ToReader(m{ "service": m{"name": "opbeans"}})) ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.AuthResult.Anonymous = true h(ctx) var actual map[string]string json.Unmarshal(w.Body.Bytes(), &actual) @@ -357,21 +358,6 @@ func TestAgentConfigNoLeak(t *testing.T) { assert.Equal(t, map[string]string{}, actual) } -func TestAgentConfigRateLimit(t *testing.T) { - h := getHandler("rum-js") - r := httptest.NewRequest(http.MethodPost, "/rum", convert.ToReader(m{ - "service": m{"name": "opbeans"}})) - ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.RateLimiter = rate.NewLimiter(rate.Limit(0), 0) - ctx.AuthResult.Anonymous = true - h(ctx) - var actual map[string]string - json.Unmarshal(w.Body.Bytes(), &actual) - assert.Equal(t, http.StatusTooManyRequests, w.Code, w.Body.String()) - assert.Equal(t, map[string]string{"error": "too many requests"}, actual) -} - func getHandler(agent string) request.Handler { kb := kibanatest.MockKibana(http.StatusOK, m{ "_id": "1", @@ -386,7 +372,7 @@ func getHandler(agent string) request.Handler { }, mockVersion, true) cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kb, cfg.Cache.Expiration) - return NewHandler(f, cfg, "") + return NewHandler(f, cfg, "", []string{"rum-js"}) } func TestIfNoneMatch(t *testing.T) { @@ -412,7 +398,7 @@ func TestAgentConfigTraceContext(t *testing.T) { client := kibana.NewConnectingClient(&kibanaCfg) cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: 5 * time.Minute}} f := agentcfg.NewKibanaFetcher(client, cfg.Cache.Expiration) - handler := NewHandler(f, cfg, "default") + handler := NewHandler(f, cfg, "default", nil) _, spans, _ := apmtest.WithTransaction(func(ctx context.Context) { // When the handler is called with a context containing // a transaction, the underlying Kibana query should create a span @@ -439,6 +425,7 @@ func newRequestContext(r *http.Request) (*request.Context, *httptest.ResponseRec w := httptest.NewRecorder() ctx := request.NewContext() ctx.Reset(w, r) + ctx.Request = withAnonymousAuthorization(ctx.Request) return ctx, w } @@ -471,6 +458,12 @@ func (c *recordingKibanaClient) Send(ctx context.Context, method string, path st return c.Client.Send(ctx, method, path, params, header, body) } +func withAnonymousAuthorization(req *http.Request) *http.Request { + return withAuthorization(req, authorizedForFunc(func(context.Context, authorization.Resource) (authorization.Result, error) { + return authorization.Result{Authorized: true, Anonymous: true}, nil + })) +} + func withAuthorization(req *http.Request, auth authorization.Authorization) *http.Request { return req.WithContext(authorization.ContextWithAuthorization(req.Context(), auth)) } diff --git a/beater/api/intake/handler.go b/beater/api/intake/handler.go index 2bfb47b7037..3d9a4d54937 100644 --- a/beater/api/intake/handler.go +++ b/beater/api/intake/handler.go @@ -30,6 +30,7 @@ import ( "github.com/elastic/beats/v7/libbeat/monitoring" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/headers" "github.com/elastic/apm-server/beater/request" "github.com/elastic/apm-server/decoder" @@ -52,7 +53,6 @@ var ( errMethodNotAllowed = errors.New("only POST requests are supported") errServerShuttingDown = errors.New("server is shutting down") errInvalidContentType = errors.New("invalid content type") - errRateLimitExceeded = errors.New("rate limit exceeded") ) // StreamHandler is an interface for handling an Elastic APM agent ND-JSON event @@ -68,27 +68,23 @@ type StreamHandler interface { ) error } +// RequestMetadataFunc is a function type supplied to Handler for extracting +// metadata from the request. This is used for conditionally injecting the +// source IP address as `client.ip` for RUM. +type RequestMetadataFunc func(*request.Context) model.Metadata + // Handler returns a request.Handler for managing intake requests for backend and rum events. -func Handler(handler StreamHandler, batchProcessor model.BatchProcessor) request.Handler { +func Handler(handler StreamHandler, requestMetadataFunc RequestMetadataFunc, batchProcessor model.BatchProcessor) request.Handler { return func(c *request.Context) { if err := validateRequest(c); err != nil { writeError(c, err) return } - if c.RateLimiter != nil { - // Call Allow once for each stream to avoid burning CPU before we - // begin reading for the first time. This prevents clients from - // repeatedly connecting and sending < batchSize events and - // disconnecting before being rate limited. - if !c.RateLimiter.Allow() { - writeError(c, errRateLimitExceeded) - return - } - + if limiter, ok := ratelimit.FromContext(c.Request.Context()); ok { // Apply rate limiting after reading but before processing any events. batchProcessor = modelprocessor.Chained{ - rateLimitBatchProcessor(c.RateLimiter, batchSize), + rateLimitBatchProcessor(limiter, batchSize), batchProcessor, } } @@ -99,12 +95,7 @@ func Handler(handler StreamHandler, batchProcessor model.BatchProcessor) request return } - metadata := model.Metadata{ - UserAgent: model.UserAgent{Original: c.RequestMetadata.UserAgent}, - Client: model.Client{IP: c.RequestMetadata.ClientIP}, - System: model.System{IP: c.RequestMetadata.SystemIP}, - } - + metadata := requestMetadataFunc(c) var result stream.Result if err := handler.HandleStream( c.Request.Context(), @@ -132,7 +123,7 @@ func rateLimitBatch(ctx context.Context, limiter *rate.Limiter, batchSize int) e ctx, cancel := context.WithTimeout(ctx, rateLimitTimeout) defer cancel() if err := limiter.WaitN(ctx, batchSize); err != nil { - return errRateLimitExceeded + return ratelimit.ErrRateLimitExceeded } return nil } @@ -191,7 +182,7 @@ func writeStreamResult(c *request.Context, sr *stream.Result) { errID = request.IDResponseErrorsMethodNotAllowed case errors.Is(err, errInvalidContentType): errID = request.IDResponseErrorsValidate - case errors.Is(err, errRateLimitExceeded): + case errors.Is(err, ratelimit.ErrRateLimitExceeded): errID = request.IDResponseErrorsRateLimit } } diff --git a/beater/api/intake/handler_test.go b/beater/api/intake/handler_test.go index 5a4b7cf1b4d..ed4074520f4 100644 --- a/beater/api/intake/handler_test.go +++ b/beater/api/intake/handler_test.go @@ -34,6 +34,7 @@ import ( "golang.org/x/time/rate" "github.com/elastic/apm-server/approvaltest" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/headers" "github.com/elastic/apm-server/beater/request" @@ -136,7 +137,7 @@ func TestIntakeHandler(t *testing.T) { tc.setup(t) // call handler - h := Handler(tc.processor, tc.batchProcessor) + h := Handler(tc.processor, emptyRequestMetadata, tc.batchProcessor) h(tc.c) require.Equal(t, string(tc.id), string(tc.c.Result.ID)) @@ -190,12 +191,15 @@ func TestRateLimiting(t *testing.T) { var tc testcaseIntakeHandler tc.path = "ratelimit.ndjson" tc.setup(t) - tc.c.RateLimiter = test.limiter + + tc.c.Request = tc.c.Request.WithContext( + ratelimit.ContextWithLimiter(tc.c.Request.Context(), test.limiter), + ) if test.preconsumed > 0 { test.limiter.AllowN(time.Now(), test.preconsumed) } - h := Handler(tc.processor, tc.batchProcessor) + h := Handler(tc.processor, emptyRequestMetadata, tc.batchProcessor) h(tc.c) if test.expectLimited { @@ -275,3 +279,7 @@ func compressedRequest(t *testing.T, compressionType string, compressPayload boo req.Header.Set(headers.ContentEncoding, compressionType) return req } + +func emptyRequestMetadata(*request.Context) model.Metadata { + return model.Metadata{} +} diff --git a/beater/api/mux.go b/beater/api/mux.go index 55e9c1ee73f..436a96d515c 100644 --- a/beater/api/mux.go +++ b/beater/api/mux.go @@ -31,6 +31,7 @@ import ( "github.com/elastic/apm-server/beater/api/config/agent" "github.com/elastic/apm-server/beater/api/intake" "github.com/elastic/apm-server/beater/api/profile" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/api/root" "github.com/elastic/apm-server/beater/authorization" "github.com/elastic/apm-server/beater/config" @@ -68,6 +69,13 @@ const ( IntakeRUMV3Path = "/intake/v3/rum/events" ) +var ( + // rumAgents holds the current and previous agent names for the + // RUM JavaScript agent. This is used for restricting which config + // is supplied to anonymous agents. + rumAgents = []string{"rum-js", "js-base"} +) + // NewMux registers apm handlers to paths building up the APM Server API. func NewMux( beatInfo beat.Info, @@ -75,6 +83,7 @@ func NewMux( report publish.Reporter, batchProcessor model.BatchProcessor, fetcher agentcfg.Fetcher, + ratelimitStore *ratelimit.Store, ) (*http.ServeMux, error) { pool := request.NewContextPool() mux := http.NewServeMux() @@ -91,6 +100,7 @@ func NewMux( authBuilder: auth, reporter: report, batchProcessor: batchProcessor, + ratelimitStore: ratelimitStore, } type route struct { @@ -140,33 +150,46 @@ type routeBuilder struct { authBuilder *authorization.Builder reporter publish.Reporter batchProcessor model.BatchProcessor + ratelimitStore *ratelimit.Store } func (r *routeBuilder) profileHandler() (request.Handler, error) { - h := profile.Handler(r.batchProcessor) + requestMetadataFunc := emptyRequestMetadata + if r.cfg.AugmentEnabled { + requestMetadataFunc = backendRequestMetadata + } + h := profile.Handler(requestMetadataFunc, r.batchProcessor) authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeEventWrite.Action) - return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, profile.MonitoringMap)...) + return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, r.ratelimitStore, profile.MonitoringMap)...) } func (r *routeBuilder) backendIntakeHandler() (request.Handler, error) { - h := intake.Handler(stream.BackendProcessor(r.cfg), r.batchProcessor) + requestMetadataFunc := emptyRequestMetadata + if r.cfg.AugmentEnabled { + requestMetadataFunc = backendRequestMetadata + } + h := intake.Handler(stream.BackendProcessor(r.cfg), requestMetadataFunc, r.batchProcessor) authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeEventWrite.Action) - return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, intake.MonitoringMap)...) + return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, r.ratelimitStore, intake.MonitoringMap)...) } func (r *routeBuilder) rumIntakeHandler(newProcessor func(*config.Config) *stream.Processor) func() (request.Handler, error) { + requestMetadataFunc := emptyRequestMetadata + if r.cfg.AugmentEnabled { + requestMetadataFunc = rumRequestMetadata + } return func() (request.Handler, error) { batchProcessor := r.batchProcessor batchProcessor = batchProcessorWithAllowedServiceNames(batchProcessor, r.cfg.RumConfig.AllowServiceNames) - h := intake.Handler(newProcessor(r.cfg), batchProcessor) - return middleware.Wrap(h, rumMiddleware(r.cfg, nil, intake.MonitoringMap)...) + h := intake.Handler(newProcessor(r.cfg), requestMetadataFunc, batchProcessor) + return middleware.Wrap(h, rumMiddleware(r.cfg, nil, r.ratelimitStore, intake.MonitoringMap)...) } } func (r *routeBuilder) sourcemapHandler() (request.Handler, error) { h := sourcemap.Handler(r.reporter) authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeSourcemapWrite.Action) - return middleware.Wrap(h, sourcemapMiddleware(r.cfg, authHandler)...) + return middleware.Wrap(h, sourcemapMiddleware(r.cfg, authHandler, r.ratelimitStore)...) } func (r *routeBuilder) rootHandler() (request.Handler, error) { @@ -177,26 +200,27 @@ func (r *routeBuilder) rootHandler() (request.Handler, error) { func (r *routeBuilder) backendAgentConfigHandler(f agentcfg.Fetcher) func() (request.Handler, error) { return func() (request.Handler, error) { authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeAgentConfigRead.Action) - return agentConfigHandler(r.cfg, authHandler, backendMiddleware, f) + return agentConfigHandler(r.cfg, authHandler, r.ratelimitStore, backendMiddleware, f) } } func (r *routeBuilder) rumAgentConfigHandler(f agentcfg.Fetcher) func() (request.Handler, error) { return func() (request.Handler, error) { - return agentConfigHandler(r.cfg, nil, rumMiddleware, f) + return agentConfigHandler(r.cfg, nil, r.ratelimitStore, rumMiddleware, f) } } -type middlewareFunc func(*config.Config, *authorization.Handler, map[request.ResultID]*monitoring.Int) []middleware.Middleware +type middlewareFunc func(*config.Config, *authorization.Handler, *ratelimit.Store, map[request.ResultID]*monitoring.Int) []middleware.Middleware func agentConfigHandler( cfg *config.Config, authHandler *authorization.Handler, + ratelimitStore *ratelimit.Store, middlewareFunc middlewareFunc, f agentcfg.Fetcher, ) (request.Handler, error) { - mw := middlewareFunc(cfg, authHandler, agent.MonitoringMap) - h := agent.NewHandler(f, cfg.KibanaAgentConfig, cfg.DefaultServiceEnvironment) + mw := middlewareFunc(cfg, authHandler, ratelimitStore, agent.MonitoringMap) + h := agent.NewHandler(f, cfg.KibanaAgentConfig, cfg.DefaultServiceEnvironment, rumAgents) if !cfg.Kibana.Enabled && cfg.AgentConfigs == nil { msg := "Agent remote configuration is disabled. " + @@ -219,37 +243,30 @@ func apmMiddleware(m map[request.ResultID]*monitoring.Int) []middleware.Middlewa } } -func backendMiddleware(cfg *config.Config, auth *authorization.Handler, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { +func backendMiddleware(cfg *config.Config, auth *authorization.Handler, ratelimitStore *ratelimit.Store, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { backendMiddleware := append(apmMiddleware(m), middleware.ResponseHeadersMiddleware(cfg.ResponseHeaders), middleware.AuthorizationMiddleware(auth, true), + middleware.AnonymousRateLimitMiddleware(ratelimitStore), ) - if cfg.AugmentEnabled { - backendMiddleware = append(backendMiddleware, middleware.SystemMetadataMiddleware()) - } return backendMiddleware } -func rumMiddleware(cfg *config.Config, _ *authorization.Handler, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { +func rumMiddleware(cfg *config.Config, auth *authorization.Handler, ratelimitStore *ratelimit.Store, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { msg := "RUM endpoint is disabled. " + "Configure the `apm-server.rum` section in apm-server.yml to enable ingestion of RUM events. " + "If you are not using the RUM agent, you can safely ignore this error." rumMiddleware := append(apmMiddleware(m), middleware.ResponseHeadersMiddleware(cfg.ResponseHeaders), middleware.ResponseHeadersMiddleware(cfg.RumConfig.ResponseHeaders), - middleware.SetRumFlagMiddleware(), - middleware.SetIPRateLimitMiddleware(cfg.RumConfig.EventRate), middleware.CORSMiddleware(cfg.RumConfig.AllowOrigins, cfg.RumConfig.AllowHeaders), middleware.AnonymousAuthorizationMiddleware(), - middleware.KillSwitchMiddleware(cfg.RumConfig.Enabled, msg), + middleware.AnonymousRateLimitMiddleware(ratelimitStore), ) - if cfg.AugmentEnabled { - rumMiddleware = append(rumMiddleware, middleware.UserMetadataMiddleware()) - } - return rumMiddleware + return append(rumMiddleware, middleware.KillSwitchMiddleware(cfg.RumConfig.Enabled, msg)) } -func sourcemapMiddleware(cfg *config.Config, auth *authorization.Handler) []middleware.Middleware { +func sourcemapMiddleware(cfg *config.Config, auth *authorization.Handler, ratelimitStore *ratelimit.Store) []middleware.Middleware { msg := "Sourcemap upload endpoint is disabled. " + "Configure the `apm-server.rum` section in apm-server.yml to enable sourcemap uploads. " + "If you are not using the RUM agent, you can safely ignore this error." @@ -257,8 +274,8 @@ func sourcemapMiddleware(cfg *config.Config, auth *authorization.Handler) []midd msg = "When APM Server is managed by Fleet, Sourcemaps must be uploaded directly to Elasticsearch." } enabled := cfg.RumConfig.Enabled && cfg.RumConfig.SourceMapping.Enabled && !cfg.DataStreams.Enabled - return append(backendMiddleware(cfg, auth, sourcemap.MonitoringMap), - middleware.KillSwitchMiddleware(enabled, msg)) + backendMiddleware := backendMiddleware(cfg, auth, ratelimitStore, sourcemap.MonitoringMap) + return append(backendMiddleware, middleware.KillSwitchMiddleware(enabled, msg)) } func rootMiddleware(cfg *config.Config, auth *authorization.Handler) []middleware.Middleware { @@ -288,3 +305,18 @@ func batchProcessorWithAllowedServiceNames(p model.BatchProcessor, allowedServic } return modelprocessor.Chained{restrictServiceName, p} } + +func emptyRequestMetadata(c *request.Context) model.Metadata { + return model.Metadata{} +} + +func backendRequestMetadata(c *request.Context) model.Metadata { + return model.Metadata{System: model.System{IP: c.SourceIP}} +} + +func rumRequestMetadata(c *request.Context) model.Metadata { + return model.Metadata{ + Client: model.Client{IP: c.SourceIP}, + UserAgent: model.UserAgent{Original: c.UserAgent}, + } +} diff --git a/beater/api/mux_intake_rum_test.go b/beater/api/mux_intake_rum_test.go index 5bf99c1ee44..2f5f728a5b2 100644 --- a/beater/api/mux_intake_rum_test.go +++ b/beater/api/mux_intake_rum_test.go @@ -29,6 +29,7 @@ import ( "github.com/elastic/apm-server/approvaltest" "github.com/elastic/apm-server/beater/api/intake" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/headers" "github.com/elastic/apm-server/beater/middleware" @@ -36,6 +37,7 @@ import ( ) func TestOPTIONS(t *testing.T) { + ratelimitStore, _ := ratelimit.NewStore(1, 1, 1) requestTaken := make(chan struct{}, 1) done := make(chan struct{}, 1) @@ -46,7 +48,7 @@ func TestOPTIONS(t *testing.T) { requestTaken <- struct{}{} <-done }, - rumMiddleware(cfg, nil, intake.MonitoringMap)...) + rumMiddleware(cfg, nil, ratelimitStore, intake.MonitoringMap)...) // use this to block the single allowed concurrent requests go func() { diff --git a/beater/api/mux_test.go b/beater/api/mux_test.go index 8ce997c1eda..7b6be8969e8 100644 --- a/beater/api/mux_test.go +++ b/beater/api/mux_test.go @@ -28,6 +28,7 @@ import ( "github.com/elastic/apm-server/agentcfg" "github.com/elastic/apm-server/approvaltest" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/beatertest" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/request" @@ -76,7 +77,8 @@ func requestToMuxerWithHeaderAndQueryString( func requestToMuxer(cfg *config.Config, r *http.Request) (*httptest.ResponseRecorder, error) { nopReporter := func(context.Context, publish.PendingReq) error { return nil } nopBatchProcessor := model.ProcessBatchFunc(func(context.Context, *model.Batch) error { return nil }) - mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg)) + ratelimitStore, _ := ratelimit.NewStore(1000, 1000, 1000) + mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg), ratelimitStore) if err != nil { return nil, err } @@ -111,7 +113,8 @@ func testMonitoringMiddleware(t *testing.T, urlPath string, monitoringMap map[re func newTestMux(t *testing.T, cfg *config.Config) http.Handler { nopReporter := func(context.Context, publish.PendingReq) error { return nil } nopBatchProcessor := model.ProcessBatchFunc(func(context.Context, *model.Batch) error { return nil }) - mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg)) + ratelimitStore, _ := ratelimit.NewStore(1000, 1000, 1000) + mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg), ratelimitStore) require.NoError(t, err) return mux } diff --git a/beater/api/profile/handler.go b/beater/api/profile/handler.go index bf1c10f7617..e1f30da55fd 100644 --- a/beater/api/profile/handler.go +++ b/beater/api/profile/handler.go @@ -55,8 +55,13 @@ const ( profileContentLengthLimit = 10 * 1024 * 1024 ) +// RequestMetadataFunc is a function type supplied to Handler for extracting +// metadata from the request. This is used for conditionally injecting the +// source IP address as `client.ip` for RUM. +type RequestMetadataFunc func(*request.Context) model.Metadata + // Handler returns a request.Handler for managing profile requests. -func Handler(processor model.BatchProcessor) request.Handler { +func Handler(requestMetadataFunc RequestMetadataFunc, processor model.BatchProcessor) request.Handler { handle := func(c *request.Context) (*result, error) { if c.Request.Method != http.MethodPost { return nil, requestError{ @@ -71,14 +76,6 @@ func Handler(processor model.BatchProcessor) request.Handler { } } - ok := c.RateLimiter == nil || c.RateLimiter.Allow() - if !ok { - return nil, requestError{ - id: request.IDResponseErrorsRateLimit, - err: errors.New("rate limit exceeded"), - } - } - var totalLimitRemaining int64 = profileContentLengthLimit var profiles []*pprof_profile.Profile var profileMetadata model.Metadata @@ -104,10 +101,7 @@ func Handler(processor model.BatchProcessor) request.Handler { } r := &decoder.LimitedReader{R: part, N: metadataContentLengthLimit} dec := decoder.NewJSONDecoder(r) - metadata := model.Metadata{ - UserAgent: model.UserAgent{Original: c.RequestMetadata.UserAgent}, - Client: model.Client{IP: c.RequestMetadata.ClientIP}, - System: model.System{IP: c.RequestMetadata.SystemIP}} + metadata := requestMetadataFunc(c) if err := v2.DecodeMetadata(dec, &metadata); err != nil { if r.N < 0 { return nil, requestError{ diff --git a/beater/api/profile/handler_test.go b/beater/api/profile/handler_test.go index 9a388eea4b3..0fa0f80eb52 100644 --- a/beater/api/profile/handler_test.go +++ b/beater/api/profile/handler_test.go @@ -31,7 +31,6 @@ import ( "strings" "testing" - "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/model" "github.com/stretchr/testify/assert" @@ -45,8 +44,6 @@ import ( const pprofContentType = `application/x-protobuf; messageType="perftools.profiles.Profile"` func TestHandler(t *testing.T) { - var rateLimit, err = ratelimit.NewStore(1, 0, 0) - require.NoError(t, err) for name, tc := range map[string]testcaseIntakeHandler{ "MethodNotAllowed": { r: httptest.NewRequest(http.MethodGet, "/", nil), @@ -60,10 +57,6 @@ func TestHandler(t *testing.T) { }(), id: request.IDResponseErrorsValidate, }, - "RateLimitExceeded": { - rateLimit: rateLimit, - id: request.IDResponseErrorsRateLimit, - }, "Closing": { batchProcessor: func(t *testing.T) model.BatchProcessor { return model.ProcessBatchFunc(func(context.Context, *model.Batch) error { @@ -199,10 +192,7 @@ func TestHandler(t *testing.T) { } { t.Run(name, func(t *testing.T) { tc.setup(t) - if tc.rateLimit != nil { - tc.c.RateLimiter = tc.rateLimit.ForIP(&http.Request{}) - } - Handler(tc.batchProcessor(t))(tc.c) + Handler(emptyRequestMetadata, tc.batchProcessor(t))(tc.c) assert.Equal(t, string(tc.id), string(tc.c.Result.ID)) resultStatus := request.MapResultIDToStatus[tc.id] @@ -225,7 +215,6 @@ type testcaseIntakeHandler struct { c *request.Context w *httptest.ResponseRecorder r *http.Request - rateLimit *ratelimit.Store batchProcessor func(t *testing.T) model.BatchProcessor reports int parts []part @@ -299,3 +288,7 @@ func prettyJSON(v interface{}) string { enc.Encode(v) return buf.String() } + +func emptyRequestMetadata(*request.Context) model.Metadata { + return model.Metadata{} +} diff --git a/beater/middleware/rum_middleware.go b/beater/api/ratelimit/context.go similarity index 52% rename from beater/middleware/rum_middleware.go rename to beater/api/ratelimit/context.go index be3f8f33d4d..d17f99597de 100644 --- a/beater/middleware/rum_middleware.go +++ b/beater/api/ratelimit/context.go @@ -15,18 +15,28 @@ // specific language governing permissions and limitations // under the License. -package middleware +package ratelimit import ( - "github.com/elastic/apm-server/beater/request" + "context" + + "github.com/pkg/errors" + "golang.org/x/time/rate" ) -// SetRumFlagMiddleware sets a rum flag in the context -func SetRumFlagMiddleware() Middleware { - return func(h request.Handler) (request.Handler, error) { - return func(c *request.Context) { - c.IsRum = true - h(c) - }, nil - } +// ErrRateLimitExceeded is returned when the rate limit is exceeded. +var ErrRateLimitExceeded = errors.New("rate limit exceeded") + +type rateLimiterKey struct{} + +// FromContext returns a rate.Limiter if one is contained in ctx, +// and a bool indicating whether one was found. +func FromContext(ctx context.Context) (*rate.Limiter, bool) { + limiter, ok := ctx.Value(rateLimiterKey{}).(*rate.Limiter) + return limiter, ok +} + +// ContextWithLimiter returns a copy of parent associated with limiter. +func ContextWithLimiter(parent context.Context, limiter *rate.Limiter) context.Context { + return context.WithValue(parent, rateLimiterKey{}, limiter) } diff --git a/beater/api/ratelimit/store.go b/beater/api/ratelimit/store.go index b9fac6c8cfd..660a3af76cd 100644 --- a/beater/api/ratelimit/store.go +++ b/beater/api/ratelimit/store.go @@ -18,11 +18,9 @@ package ratelimit import ( - "net/http" + "net" "sync" - "github.com/elastic/apm-server/utility" - "github.com/hashicorp/golang-lru/simplelru" "github.com/pkg/errors" "golang.org/x/time/rate" @@ -63,8 +61,9 @@ func NewStore(size, rateLimit, burstFactor int) (*Store, error) { return &store, nil } -// acquire returns a rate.Limiter instance for the given key -func (s *Store) acquire(key string) *rate.Limiter { +// ForIP returns a rate limiter for the given IP. +func (s *Store) ForIP(ip net.IP) *rate.Limiter { + key := ip.String() // lock get and add action for cache to allow proper eviction handling without // race conditions. @@ -83,11 +82,3 @@ func (s *Store) acquire(key string) *rate.Limiter { } return limiter } - -// ForIP returns a rate limiter for the given request IP -func (s *Store) ForIP(r *http.Request) *rate.Limiter { - if s == nil { - return nil - } - return s.acquire(utility.RemoteAddr(r)) -} diff --git a/beater/api/ratelimit/store_test.go b/beater/api/ratelimit/store_test.go index e6559e089f0..4706e9dbff8 100644 --- a/beater/api/ratelimit/store_test.go +++ b/beater/api/ratelimit/store_test.go @@ -18,7 +18,7 @@ package ratelimit import ( - "net/http" + "net" "testing" "time" @@ -38,7 +38,6 @@ func TestCacheInitFails(t *testing.T) { c, err := NewStore(test.size, test.limit, 3) assert.Error(t, err) assert.Nil(t, c) - assert.Nil(t, c.ForIP(&http.Request{})) } } @@ -50,20 +49,20 @@ func TestCacheEviction(t *testing.T) { require.NoError(t, err) // add new limiter - rlA := store.acquire("a") + rlA := store.ForIP(net.ParseIP("127.0.0.1")) rlA.AllowN(time.Now(), 3) // add new limiter - rlB := store.acquire("b") + rlB := store.ForIP(net.ParseIP("127.0.0.2")) rlB.AllowN(time.Now(), 2) // reuse evicted limiter rlA - rlC := store.acquire("c") + rlC := store.ForIP(net.ParseIP("127.0.0.3")) assert.False(t, rlC.Allow()) assert.Equal(t, rlC, store.evictedLimiter) // reuse evicted limiter rlB - rlD := store.acquire("a") + rlD := store.ForIP(net.ParseIP("127.0.0.1")) assert.True(t, rlD.Allow()) assert.False(t, rlD.Allow()) assert.Equal(t, rlD, store.evictedLimiter) @@ -77,22 +76,6 @@ func TestCacheEviction(t *testing.T) { func TestCacheOk(t *testing.T) { store, err := NewStore(1, 1, 1) require.NoError(t, err) - limiter := store.acquire("a") + limiter := store.ForIP(net.ParseIP("127.0.0.1")) assert.NotNil(t, limiter) } - -func TestRateLimitPerIP(t *testing.T) { - store, err := NewStore(2, 1, 1) - require.NoError(t, err) - - var reqFrom = func(ip string) *http.Request { - r := http.Request{} - r.Header = http.Header{} - r.Header.Set("X-Real-Ip", ip) - return &r - } - assert.True(t, store.ForIP(reqFrom("10.10.10.1")).Allow()) - assert.False(t, store.ForIP(reqFrom("10.10.10.1")).Allow()) - assert.True(t, store.ForIP(reqFrom("10.10.10.2")).Allow()) - assert.False(t, store.ForIP(reqFrom("10.10.10.3")).Allow()) -} diff --git a/beater/http.go b/beater/http.go index 994e4b396ed..d7bcb77bab5 100644 --- a/beater/http.go +++ b/beater/http.go @@ -29,6 +29,7 @@ import ( "github.com/elastic/apm-server/agentcfg" "github.com/elastic/apm-server/beater/api" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/model" "github.com/elastic/apm-server/model/modelprocessor" @@ -47,7 +48,16 @@ type httpServer struct { grpcListener net.Listener } -func newHTTPServer(logger *logp.Logger, info beat.Info, cfg *config.Config, tracer *apm.Tracer, reporter publish.Reporter, batchProcessor model.BatchProcessor, f agentcfg.Fetcher) (*httpServer, error) { +func newHTTPServer( + logger *logp.Logger, + info beat.Info, + cfg *config.Config, + tracer *apm.Tracer, + reporter publish.Reporter, + batchProcessor model.BatchProcessor, + agentcfgFetcher agentcfg.Fetcher, + ratelimitStore *ratelimit.Store, +) (*httpServer, error) { // Add a model processor that checks authorization for the agent and service for each event. batchProcessor = modelprocessor.Chained{ @@ -55,7 +65,7 @@ func newHTTPServer(logger *logp.Logger, info beat.Info, cfg *config.Config, trac batchProcessor, } - mux, err := api.NewMux(info, cfg, reporter, batchProcessor, f) + mux, err := api.NewMux(info, cfg, reporter, batchProcessor, agentcfgFetcher, ratelimitStore) if err != nil { return nil, err } diff --git a/beater/middleware/authorization_middleware.go b/beater/middleware/authorization_middleware.go index 2de8adb25f1..724b87881d4 100644 --- a/beater/middleware/authorization_middleware.go +++ b/beater/middleware/authorization_middleware.go @@ -65,6 +65,7 @@ func AnonymousAuthorizationMiddleware() Middleware { return func(h request.Handler) (request.Handler, error) { return func(c *request.Context) { auth := authorization.AnonymousAuth{} + c.AuthResult = authorization.Result{Authorized: true, Anonymous: true} c.Request = c.Request.WithContext(authorization.ContextWithAuthorization(c.Request.Context(), auth)) h(c) }, nil diff --git a/beater/middleware/rate_limit_middleware.go b/beater/middleware/rate_limit_middleware.go index 8b71f65cfe6..cb954432517 100644 --- a/beater/middleware/rate_limit_middleware.go +++ b/beater/middleware/rate_limit_middleware.go @@ -19,20 +19,33 @@ package middleware import ( "github.com/elastic/apm-server/beater/api/ratelimit" - "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/request" ) -const burstMultiplier = 3 - -// SetIPRateLimitMiddleware sets a rate limiter -func SetIPRateLimitMiddleware(cfg config.EventRate) Middleware { - store, err := ratelimit.NewStore(cfg.LruSize, cfg.Limit, burstMultiplier) - +// AnonymousRateLimitMiddleware adds a rate.Limiter to the context of anonymous +// requests, first ensuring the client is allowed to perform a single event and +// responding with 429 Too Many Requests if it is not. +// +// This middleware must be wrapped by AuthorizationMiddleware, as it depends on +// the value of c.AuthResult.Anonymous. +func AnonymousRateLimitMiddleware(store *ratelimit.Store) Middleware { return func(h request.Handler) (request.Handler, error) { return func(c *request.Context) { - c.RateLimiter = store.ForIP(c.Request) + if c.AuthResult.Anonymous { + limiter := store.ForIP(c.SourceIP) + if !limiter.Allow() { + c.Result.SetWithError( + request.IDResponseErrorsRateLimit, + ratelimit.ErrRateLimitExceeded, + ) + c.Write() + return + } + ctx := c.Request.Context() + ctx = ratelimit.ContextWithLimiter(ctx, limiter) + c.Request = c.Request.WithContext(ctx) + } h(c) - }, err + }, nil } } diff --git a/beater/middleware/rate_limit_middleware_test.go b/beater/middleware/rate_limit_middleware_test.go new file mode 100644 index 00000000000..bd6c9d18b89 --- /dev/null +++ b/beater/middleware/rate_limit_middleware_test.go @@ -0,0 +1,107 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/apm-server/beater/api/ratelimit" + "github.com/elastic/apm-server/beater/request" +) + +func TestAnonymousRateLimitMiddleware(t *testing.T) { + type test struct { + burst int + anonymous bool + + expectStatusCode int + expectAllow bool + } + for _, test := range []test{{ + burst: 0, + anonymous: false, + expectStatusCode: http.StatusOK, + }, { + burst: 0, + anonymous: true, + expectStatusCode: http.StatusTooManyRequests, + }, { + burst: 1, + anonymous: true, + expectStatusCode: http.StatusOK, + expectAllow: false, + }, { + burst: 2, + anonymous: true, + expectStatusCode: http.StatusOK, + expectAllow: true, + }} { + store, _ := ratelimit.NewStore(1, 1, test.burst) + middleware := AnonymousRateLimitMiddleware(store) + handler := func(c *request.Context) { + limiter, ok := ratelimit.FromContext(c.Request.Context()) + if test.anonymous { + require.True(t, ok) + assert.Equal(t, test.expectAllow, limiter.Allow()) + } else { + require.False(t, ok) + } + } + wrapped, err := middleware(handler) + require.NoError(t, err) + + c := request.NewContext() + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + c.Reset(w, r) + c.AuthResult.Anonymous = test.anonymous + + wrapped(c) + assert.Equal(t, test.expectStatusCode, w.Code) + } +} + +func TestAnonymousRateLimitMiddlewareForIP(t *testing.T) { + store, _ := ratelimit.NewStore(2, 1, 1) + middleware := AnonymousRateLimitMiddleware(store) + handler := func(c *request.Context) {} + wrapped, err := middleware(handler) + require.NoError(t, err) + + requestWithIP := func(ip string) int { + c := request.NewContext() + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.RemoteAddr = ip + c.Reset(w, r) + c.AuthResult.Anonymous = true + wrapped(c) + return w.Code + } + assert.Equal(t, http.StatusOK, requestWithIP("10.1.1.1")) + assert.Equal(t, http.StatusTooManyRequests, requestWithIP("10.1.1.1")) + assert.Equal(t, http.StatusOK, requestWithIP("10.1.1.2")) + + // ratelimit.Store size is 2: the 3rd IP reuses an existing (depleted) rate limiter. + assert.Equal(t, http.StatusTooManyRequests, requestWithIP("10.1.1.3")) +} diff --git a/beater/middleware/request_metadata_middleware.go b/beater/middleware/request_metadata_middleware.go deleted file mode 100644 index 9bb57d44808..00000000000 --- a/beater/middleware/request_metadata_middleware.go +++ /dev/null @@ -1,46 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package middleware - -import ( - "github.com/elastic/apm-server/beater/request" - "github.com/elastic/apm-server/utility" -) - -// UserMetadataMiddleware returns a Middleware recording request-level -// user metadata (e.g. user-agent and source IP) in the request's context. -func UserMetadataMiddleware() Middleware { - return func(h request.Handler) (request.Handler, error) { - return func(c *request.Context) { - c.RequestMetadata.UserAgent = utility.UserAgentHeader(c.Request.Header) - c.RequestMetadata.ClientIP = utility.ExtractIP(c.Request) - h(c) - }, nil - } -} - -// SystemMetadataMiddleware returns a Middleware recording request-level -// system metadata (e.g. source IP) in the request's context. -func SystemMetadataMiddleware() Middleware { - return func(h request.Handler) (request.Handler, error) { - return func(c *request.Context) { - c.RequestMetadata.SystemIP = utility.ExtractIP(c.Request) - h(c) - }, nil - } -} diff --git a/beater/middleware/request_metadata_middleware_test.go b/beater/middleware/request_metadata_middleware_test.go deleted file mode 100644 index ac69123c5c3..00000000000 --- a/beater/middleware/request_metadata_middleware_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package middleware - -import ( - "fmt" - "net" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/elastic/apm-server/beater/beatertest" -) - -func TestUserMetadataMiddleware(t *testing.T) { - type test struct { - remoteAddr string - userAgent []string - expectedIP net.IP - expectedUserAgent string - } - - ua1 := "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36" - ua2 := "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.14; rv:67.0) Gecko/20100101 Firefox/67.0" - tests := []test{ - {remoteAddr: "1.2.3.4:1234", expectedIP: net.ParseIP("1.2.3.4"), userAgent: []string{ua1, ua2}, expectedUserAgent: fmt.Sprintf("%s, %s", ua1, ua2)}, - {remoteAddr: "not-an-ip:1234", userAgent: []string{ua1}, expectedUserAgent: ua1}, - {remoteAddr: ""}, - } - - for _, test := range tests { - c, _ := beatertest.DefaultContextWithResponseRecorder() - c.Request.RemoteAddr = test.remoteAddr - for _, ua := range test.userAgent { - c.Request.Header.Add("User-Agent", ua) - } - - Apply(UserMetadataMiddleware(), beatertest.HandlerIdle)(c) - assert.Equal(t, test.expectedUserAgent, c.RequestMetadata.UserAgent) - assert.Equal(t, test.expectedIP, c.RequestMetadata.ClientIP) - } -} - -func TestSystemMetadataMiddleware(t *testing.T) { - type test struct { - remoteAddr string - expectedIP net.IP - } - tests := []test{ - {remoteAddr: "1.2.3.4:1234", expectedIP: net.ParseIP("1.2.3.4")}, - {remoteAddr: "not-an-ip:1234"}, - {remoteAddr: ""}, - } - - for _, test := range tests { - c, _ := beatertest.DefaultContextWithResponseRecorder() - c.Request.RemoteAddr = test.remoteAddr - - Apply(SystemMetadataMiddleware(), beatertest.HandlerIdle)(c) - assert.Equal(t, test.expectedIP, c.RequestMetadata.SystemIP) - } -} diff --git a/beater/request/context.go b/beater/request/context.go index 118643ef412..4478ce9a8e3 100644 --- a/beater/request/context.go +++ b/beater/request/context.go @@ -23,13 +23,12 @@ import ( "net/http" "strings" - "golang.org/x/time/rate" - "github.com/elastic/beats/v7/libbeat/logp" "github.com/elastic/apm-server/beater/authorization" "github.com/elastic/apm-server/beater/headers" logs "github.com/elastic/apm-server/log" + "github.com/elastic/apm-server/utility" ) const ( @@ -43,26 +42,17 @@ var ( // Context abstracts request and response information for http requests type Context struct { - Request *http.Request - Logger *logp.Logger - RateLimiter *rate.Limiter - AuthResult authorization.Result - IsRum bool - Result Result - RequestMetadata Metadata + Request *http.Request + Logger *logp.Logger + AuthResult authorization.Result + Result Result + SourceIP net.IP + UserAgent string w http.ResponseWriter writeAttempts int } -// Metadata contains metadata extracted from the request by middleware, -// and should be merged into the event metadata. -type Metadata struct { - ClientIP net.IP - SystemIP net.IP - UserAgent string -} - // NewContext creates an empty Context struct func NewContext() *Context { return &Context{} @@ -72,23 +62,15 @@ func NewContext() *Context { func (c *Context) Reset(w http.ResponseWriter, r *http.Request) { c.Request = r c.Logger = nil - c.RateLimiter = nil c.AuthResult = authorization.Result{} - c.IsRum = false c.Result.Reset() - c.RequestMetadata.Reset() + c.SourceIP = utility.ExtractIP(r) + c.UserAgent = utility.UserAgentHeader(r.Header) c.w = w c.writeAttempts = 0 } -// Reset sets all attribtues of the Metadata instance to it's zero value -func (m *Metadata) Reset() { - m.ClientIP = nil - m.SystemIP = nil - m.UserAgent = "" -} - // Header returns the http.Header of the context's writer func (c *Context) Header() http.Header { return c.w.Header() diff --git a/beater/request/context_test.go b/beater/request/context_test.go index aa916ac5158..2ec5eb9e1ed 100644 --- a/beater/request/context_test.go +++ b/beater/request/context_test.go @@ -18,6 +18,7 @@ package request import ( + "net" "net/http" "net/http/httptest" "reflect" @@ -37,7 +38,11 @@ func TestContext_Reset(t *testing.T) { w1.WriteHeader(http.StatusServiceUnavailable) w2 := httptest.NewRecorder() r1 := httptest.NewRequest(http.MethodGet, "/", nil) + r1.RemoteAddr = "10.1.2.3:4321" + r1.Header.Set("User-Agent", "ua1") r2 := httptest.NewRequest(http.MethodHead, "/new", nil) + r2.RemoteAddr = "10.1.2.3:1234" + r2.Header.Set("User-Agent", "ua2") c := Context{ Request: r1, w: w1, @@ -65,8 +70,10 @@ func TestContext_Reset(t *testing.T) { assert.Equal(t, 0, c.writeAttempts) case "Result": assertResultIsEmpty(t, cVal.Field(i).Interface().(Result)) - case "RequestMetadata": - assert.Equal(t, Metadata{}, cVal.Field(i).Interface().(Metadata)) + case "SourceIP": + assert.Equal(t, net.ParseIP("10.1.2.3"), cVal.Field(i).Interface()) + case "UserAgent": + assert.Equal(t, "ua2", cVal.Field(i).Interface()) default: assert.Empty(t, cVal.Field(i).Interface(), cType.Field(i).Name) } diff --git a/beater/server.go b/beater/server.go index 7d144321402..11179ebb31b 100644 --- a/beater/server.go +++ b/beater/server.go @@ -32,6 +32,7 @@ import ( "github.com/elastic/beats/v7/libbeat/version" "github.com/elastic/apm-server/agentcfg" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/authorization" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/interceptors" @@ -116,16 +117,24 @@ func newServer( reporter publish.Reporter, batchProcessor model.BatchProcessor, ) (server, error) { - fetcher := agentcfg.NewFetcher(cfg) - httpServer, err := newHTTPServer(logger, info, cfg, tracer, reporter, batchProcessor, fetcher) + agentcfgFetcher := agentcfg.NewFetcher(cfg) + ratelimitStore, err := ratelimit.NewStore( + cfg.RumConfig.EventRate.LruSize, + cfg.RumConfig.EventRate.Limit, + 3, // burst multiplier + ) if err != nil { return server{}, err } - grpcServer, err := newGRPCServer(logger, cfg, tracer, batchProcessor, httpServer.TLSConfig, fetcher) + httpServer, err := newHTTPServer(logger, info, cfg, tracer, reporter, batchProcessor, agentcfgFetcher, ratelimitStore) if err != nil { return server{}, err } - jaegerServer, err := jaeger.NewServer(logger, cfg, tracer, batchProcessor, fetcher) + grpcServer, err := newGRPCServer(logger, cfg, tracer, batchProcessor, httpServer.TLSConfig, agentcfgFetcher, ratelimitStore) + if err != nil { + return server{}, err + } + jaegerServer, err := jaeger.NewServer(logger, cfg, tracer, batchProcessor, agentcfgFetcher) if err != nil { return server{}, err } @@ -144,7 +153,8 @@ func newGRPCServer( tracer *apm.Tracer, batchProcessor model.BatchProcessor, tlsConfig *tls.Config, - fetcher agentcfg.Fetcher, + agentcfgFetcher agentcfg.Fetcher, + ratelimitStore *ratelimit.Store, ) (*grpc.Server, error) { // TODO(axw) share auth builder with beater/api. authBuilder, err := authorization.NewBuilder(cfg.AgentAuth) @@ -152,10 +162,11 @@ func newGRPCServer( return nil, err } - // NOTE(axw) even if TLS is enabled we should not use grpc.Creds, as TLS is handled by the net/http server. apmInterceptor := apmgrpc.NewUnaryServerInterceptor(apmgrpc.WithRecovery(), apmgrpc.WithTracer(tracer)) authInterceptor := newAuthUnaryServerInterceptor(authBuilder) + // Note that we intentionally do not use a grpc.Creds ServerOption + // even if TLS is enabled, as TLS is handled by the net/http server. logger = logger.Named("grpc") srv := grpc.NewServer( grpc.ChainUnaryInterceptor( @@ -165,6 +176,9 @@ func newGRPCServer( interceptors.Metrics(logger, otlp.RegistryMonitoringMaps, jaeger.RegistryMonitoringMaps), interceptors.Timeout(), authInterceptor, + + // TODO(axw) add a rate limiting interceptor here once we've + // updated authInterceptor to handle auth for Jaeger requests. ), ) @@ -182,7 +196,7 @@ func newGRPCServer( batchProcessor, } - jaeger.RegisterGRPCServices(srv, authBuilder, jaeger.ElasticAuthTag, logger, batchProcessor, fetcher) + jaeger.RegisterGRPCServices(srv, authBuilder, jaeger.ElasticAuthTag, logger, batchProcessor, agentcfgFetcher) if err := otlp.RegisterGRPCServices(srv, batchProcessor); err != nil { return nil, err } diff --git a/beater/tracing.go b/beater/tracing.go index 8d37f979ea5..6a9f791795b 100644 --- a/beater/tracing.go +++ b/beater/tracing.go @@ -27,6 +27,7 @@ import ( "github.com/elastic/apm-server/agentcfg" "github.com/elastic/apm-server/beater/api" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/model" "github.com/elastic/apm-server/publish" @@ -59,7 +60,18 @@ func newTracerServer(listener net.Listener, logger *logp.Logger) (*tracerServer, } }) cfg := config.DefaultConfig() - mux, err := api.NewMux(beat.Info{}, cfg, nopReporter, processBatch, agentcfg.NewFetcher(cfg)) + ratelimitStore, err := ratelimit.NewStore(1, 1, 1) // unused, arbitrary params + if err != nil { + return nil, err + } + mux, err := api.NewMux( + beat.Info{}, + cfg, + nopReporter, + processBatch, + agentcfg.NewFetcher(cfg), + ratelimitStore, + ) if err != nil { return nil, err } diff --git a/systemtest/rum_test.go b/systemtest/rum_test.go index 7f72a3ea6d2..4d82b0a71ef 100644 --- a/systemtest/rum_test.go +++ b/systemtest/rum_test.go @@ -192,7 +192,12 @@ func TestRUMRateLimit(t *testing.T) { g.Go(func() error { return sendEvents("10.11.12.13", srv.Config.RUM.RateLimit.EventLimit) }) g.Go(func() error { return sendEvents("10.11.12.14", srv.Config.RUM.RateLimit.EventLimit) }) g.Go(func() error { return sendEvents("10.11.12.15", srv.Config.RUM.RateLimit.EventLimit) }) - assert.EqualError(t, g.Wait(), `429 Too Many Requests ({"accepted":0,"errors":[{"message":"rate limit exceeded"}]})`) + err = g.Wait() + require.Error(t, err) + + // The exact error differs, depending on whether rate limiting was applied at the request + // level, or at the event stream level. Either could occur. + assert.Regexp(t, `429 Too Many Requests .*`, err.Error()) } func sendRUMEventsPayload(t *testing.T, srv *apmservertest.Server, payloadFile string) {