diff --git a/beater/api/config/agent/handler.go b/beater/api/config/agent/handler.go index 04086dcc9d2..f77a39f24fa 100644 --- a/beater/api/config/agent/handler.go +++ b/beater/api/config/agent/handler.go @@ -52,18 +52,21 @@ var ( registry = monitoring.Default.NewRegistry("apm-server.acm") errCacheControl = fmt.Sprintf("max-age=%v, must-revalidate", errMaxAgeDuration.Seconds()) - - // rumAgents keywords (new and old) - rumAgents = []string{"rum-js", "js-base"} ) type handler struct { f agentcfg.Fetcher + allowAnonymousAgents []string cacheControl, defaultServiceEnvironment string } -func NewHandler(f agentcfg.Fetcher, config config.KibanaAgentConfig, defaultServiceEnvironment string) request.Handler { +func NewHandler( + f agentcfg.Fetcher, + config config.KibanaAgentConfig, + defaultServiceEnvironment string, + allowAnonymousAgents []string, +) request.Handler { if f == nil { panic("fetcher must not be nil") } @@ -72,6 +75,7 @@ func NewHandler(f agentcfg.Fetcher, config config.KibanaAgentConfig, defaultServ f: f, cacheControl: cacheControl, defaultServiceEnvironment: defaultServiceEnvironment, + allowAnonymousAgents: allowAnonymousAgents, } return h.Handle @@ -83,13 +87,6 @@ func (h *handler) Handle(c *request.Context) { // error handling c.Header().Set(headers.CacheControl, errCacheControl) - ok := c.RateLimiter == nil || c.RateLimiter.Allow() - if !ok { - c.Result.SetDefault(request.IDResponseErrorsRateLimit) - c.Write() - return - } - query, queryErr := buildQuery(c) if queryErr != nil { extractQueryError(c, queryErr) @@ -100,26 +97,29 @@ func (h *handler) Handle(c *request.Context) { query.Service.Environment = h.defaultServiceEnvironment } - if !c.AuthResult.Anonymous { - // The exact agent is not always known for anonymous clients, so we do not - // issue a secondary authorization check for them. Instead, we issue the - // request and filter the results using query.InsecureAgents. - authResource := authorization.Resource{ServiceName: query.Service.Name} - if result, err := authorization.AuthorizedFor(c.Request.Context(), authResource); err != nil { - c.Result.SetDefault(request.IDResponseErrorsServiceUnavailable) - c.Result.Err = err - c.Write() - return - } else if !result.Authorized { - id := request.IDResponseErrorsUnauthorized - status := request.MapResultIDToStatus[id] - if result.Reason != "" { - status.Keyword = result.Reason - } - c.Result.Set(id, status.Code, status.Keyword, nil, nil) - c.Write() - return + // Only service, and not agent, is known for config queries. + // For anonymous/untrusted agents, we filter the results using + // query.InsecureAgents below. + authResource := authorization.Resource{ServiceName: query.Service.Name} + authResult, err := authorization.AuthorizedFor(c.Request.Context(), authResource) + if err != nil { + c.Result.SetDefault(request.IDResponseErrorsServiceUnavailable) + c.Result.Err = err + c.Write() + return + } + if !authResult.Authorized { + id := request.IDResponseErrorsUnauthorized + status := request.MapResultIDToStatus[id] + if authResult.Reason != "" { + status.Keyword = authResult.Reason } + c.Result.Set(id, status.Code, status.Keyword, nil, nil) + c.Write() + return + } + if authResult.Anonymous { + query.InsecureAgents = h.allowAnonymousAgents } result, err := h.f.Fetch(c.Request.Context(), query) @@ -184,9 +184,6 @@ func buildQuery(c *request.Context) (agentcfg.Query, error) { return query, errors.New(agentcfg.ServiceName + " is required") } - if c.IsRum { - query.InsecureAgents = rumAgents - } query.Etag = ifNoneMatch(c) return query, nil } diff --git a/beater/api/config/agent/handler_test.go b/beater/api/config/agent/handler_test.go index 7dcbc1dc079..7ab065f84d3 100644 --- a/beater/api/config/agent/handler_test.go +++ b/beater/api/config/agent/handler_test.go @@ -32,7 +32,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.elastic.co/apm/apmtest" - "golang.org/x/time/rate" "github.com/elastic/beats/v7/libbeat/common" libkibana "github.com/elastic/beats/v7/libbeat/kibana" @@ -173,18 +172,12 @@ func TestAgentConfigHandler(t *testing.T) { var cfg = config.KibanaAgentConfig{Cache: config.Cache{Expiration: 4 * time.Second}} for _, tc := range testcases { f := agentcfg.NewKibanaFetcher(tc.kbClient, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) r := httptest.NewRequest(tc.method, target(tc.queryParams), nil) for k, v := range tc.requestHeader { r.Header.Set(k, v) } ctx, w := newRequestContext(r) - ctx.AuthResult.Authorized = true - ctx.Request = withAuthorization(ctx.Request, - authorizedForFunc(func(context.Context, authorization.Resource) (authorization.Result, error) { - return authorization.Result{Authorized: true}, nil - }), - ) h(ctx) require.Equal(t, tc.respStatus, w.Code) @@ -202,32 +195,46 @@ func TestAgentConfigHandlerAnonymousAccess(t *testing.T) { kbClient := kibanatest.MockKibana(http.StatusUnauthorized, m{"error": "Unauthorized"}, mockVersion, true) cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kbClient, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) for _, tc := range []struct { - anonymous bool - response string + anonymous bool + response string + authResource *authorization.Resource }{{ - anonymous: false, - response: `{"error":"APM Server is not authorized to query Kibana. Please configure apm-server.kibana.username and apm-server.kibana.password, and ensure the user has the necessary privileges."}`, + anonymous: false, + response: `{"error":"APM Server is not authorized to query Kibana. Please configure apm-server.kibana.username and apm-server.kibana.password, and ensure the user has the necessary privileges."}`, + authResource: &authorization.Resource{ServiceName: "opbeans"}, }, { - anonymous: true, - response: `{"error":"Unauthorized"}`, + anonymous: true, + response: `{"error":"Unauthorized"}`, + authResource: &authorization.Resource{ServiceName: "opbeans"}, }} { r := httptest.NewRequest(http.MethodGet, target(map[string]string{"service.name": "opbeans"}), nil) - ctx, w := newRequestContext(r) - ctx.AuthResult.Authorized = true - ctx.AuthResult.Anonymous = tc.anonymous - ctx.Request = withAuthorization(ctx.Request, authorization.AnonymousAuth{}) - h(ctx) + c, w := newRequestContext(r) + c.AuthResult.Authorized = true + c.AuthResult.Anonymous = tc.anonymous + + var requestedResource *authorization.Resource + c.Request = withAuthorization(c.Request, + authorizedForFunc(func(ctx context.Context, resource authorization.Resource) (authorization.Result, error) { + if requestedResource != nil { + panic("expected only one AuthorizedFor request") + } + requestedResource = &resource + return c.AuthResult, nil + }), + ) + h(c) assert.Equal(t, tc.response+"\n", w.Body.String()) + assert.Equal(t, tc.authResource, requestedResource) } } func TestAgentConfigHandlerAuthorizedForService(t *testing.T) { cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(nil, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) r := httptest.NewRequest(http.MethodGet, target(map[string]string{"service.name": "opbeans"}), nil) ctx, w := newRequestContext(r) @@ -249,7 +256,7 @@ func TestAgentConfigHandlerAuthorizedForService(t *testing.T) { func TestAgentConfigHandler_NoKibanaClient(t *testing.T) { cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(nil, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) w := sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{ "service": m{"name": "opbeans-node"}}))) @@ -268,7 +275,7 @@ func TestAgentConfigHandler_PostOk(t *testing.T) { var cfg = config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kb, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "") + h := NewHandler(f, cfg, "", nil) w := sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{ "service": m{"name": "opbeans-node"}}))) @@ -289,7 +296,7 @@ func TestAgentConfigHandler_DefaultServiceEnvironment(t *testing.T) { var cfg = config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kb, cfg.Cache.Expiration) - h := NewHandler(f, cfg, "default") + h := NewHandler(f, cfg, "default", nil) sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{"service": m{"name": "opbeans-node", "environment": "specified"}}))) sendRequest(h, httptest.NewRequest(http.MethodPost, "/config", convert.ToReader(m{"service": m{"name": "opbeans-node"}}))) @@ -306,8 +313,6 @@ func TestAgentConfigRum(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/rum", convert.ToReader(m{ "service": m{"name": "opbeans"}})) ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.AuthResult.Anonymous = true h(ctx) var actual map[string]string json.Unmarshal(w.Body.Bytes(), &actual) @@ -320,8 +325,6 @@ func TestAgentConfigRumEtag(t *testing.T) { h := getHandler("rum-js") r := httptest.NewRequest(http.MethodGet, "/rum?ifnonematch=123&service.name=opbeans", nil) ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.AuthResult.Anonymous = true h(ctx) assert.Equal(t, http.StatusNotModified, w.Code, w.Body.String()) } @@ -333,7 +336,7 @@ func TestAgentConfigNotRum(t *testing.T) { ctx, w := newRequestContext(r) ctx.Request = withAuthorization(ctx.Request, authorizedForFunc(func(context.Context, authorization.Resource) (authorization.Result, error) { - return authorization.Result{Authorized: true}, nil + return authorization.Result{Authorized: true, Anonymous: false}, nil }), ) h(ctx) @@ -348,8 +351,6 @@ func TestAgentConfigNoLeak(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/rum", convert.ToReader(m{ "service": m{"name": "opbeans"}})) ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.AuthResult.Anonymous = true h(ctx) var actual map[string]string json.Unmarshal(w.Body.Bytes(), &actual) @@ -357,21 +358,6 @@ func TestAgentConfigNoLeak(t *testing.T) { assert.Equal(t, map[string]string{}, actual) } -func TestAgentConfigRateLimit(t *testing.T) { - h := getHandler("rum-js") - r := httptest.NewRequest(http.MethodPost, "/rum", convert.ToReader(m{ - "service": m{"name": "opbeans"}})) - ctx, w := newRequestContext(r) - ctx.IsRum = true - ctx.RateLimiter = rate.NewLimiter(rate.Limit(0), 0) - ctx.AuthResult.Anonymous = true - h(ctx) - var actual map[string]string - json.Unmarshal(w.Body.Bytes(), &actual) - assert.Equal(t, http.StatusTooManyRequests, w.Code, w.Body.String()) - assert.Equal(t, map[string]string{"error": "too many requests"}, actual) -} - func getHandler(agent string) request.Handler { kb := kibanatest.MockKibana(http.StatusOK, m{ "_id": "1", @@ -386,7 +372,7 @@ func getHandler(agent string) request.Handler { }, mockVersion, true) cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: time.Nanosecond}} f := agentcfg.NewKibanaFetcher(kb, cfg.Cache.Expiration) - return NewHandler(f, cfg, "") + return NewHandler(f, cfg, "", []string{"rum-js"}) } func TestIfNoneMatch(t *testing.T) { @@ -412,7 +398,7 @@ func TestAgentConfigTraceContext(t *testing.T) { client := kibana.NewConnectingClient(&kibanaCfg) cfg := config.KibanaAgentConfig{Cache: config.Cache{Expiration: 5 * time.Minute}} f := agentcfg.NewKibanaFetcher(client, cfg.Cache.Expiration) - handler := NewHandler(f, cfg, "default") + handler := NewHandler(f, cfg, "default", nil) _, spans, _ := apmtest.WithTransaction(func(ctx context.Context) { // When the handler is called with a context containing // a transaction, the underlying Kibana query should create a span @@ -439,6 +425,7 @@ func newRequestContext(r *http.Request) (*request.Context, *httptest.ResponseRec w := httptest.NewRecorder() ctx := request.NewContext() ctx.Reset(w, r) + ctx.Request = withAnonymousAuthorization(ctx.Request) return ctx, w } @@ -471,6 +458,12 @@ func (c *recordingKibanaClient) Send(ctx context.Context, method string, path st return c.Client.Send(ctx, method, path, params, header, body) } +func withAnonymousAuthorization(req *http.Request) *http.Request { + return withAuthorization(req, authorizedForFunc(func(context.Context, authorization.Resource) (authorization.Result, error) { + return authorization.Result{Authorized: true, Anonymous: true}, nil + })) +} + func withAuthorization(req *http.Request, auth authorization.Authorization) *http.Request { return req.WithContext(authorization.ContextWithAuthorization(req.Context(), auth)) } diff --git a/beater/api/intake/handler.go b/beater/api/intake/handler.go index 2bfb47b7037..3d9a4d54937 100644 --- a/beater/api/intake/handler.go +++ b/beater/api/intake/handler.go @@ -30,6 +30,7 @@ import ( "github.com/elastic/beats/v7/libbeat/monitoring" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/headers" "github.com/elastic/apm-server/beater/request" "github.com/elastic/apm-server/decoder" @@ -52,7 +53,6 @@ var ( errMethodNotAllowed = errors.New("only POST requests are supported") errServerShuttingDown = errors.New("server is shutting down") errInvalidContentType = errors.New("invalid content type") - errRateLimitExceeded = errors.New("rate limit exceeded") ) // StreamHandler is an interface for handling an Elastic APM agent ND-JSON event @@ -68,27 +68,23 @@ type StreamHandler interface { ) error } +// RequestMetadataFunc is a function type supplied to Handler for extracting +// metadata from the request. This is used for conditionally injecting the +// source IP address as `client.ip` for RUM. +type RequestMetadataFunc func(*request.Context) model.Metadata + // Handler returns a request.Handler for managing intake requests for backend and rum events. -func Handler(handler StreamHandler, batchProcessor model.BatchProcessor) request.Handler { +func Handler(handler StreamHandler, requestMetadataFunc RequestMetadataFunc, batchProcessor model.BatchProcessor) request.Handler { return func(c *request.Context) { if err := validateRequest(c); err != nil { writeError(c, err) return } - if c.RateLimiter != nil { - // Call Allow once for each stream to avoid burning CPU before we - // begin reading for the first time. This prevents clients from - // repeatedly connecting and sending < batchSize events and - // disconnecting before being rate limited. - if !c.RateLimiter.Allow() { - writeError(c, errRateLimitExceeded) - return - } - + if limiter, ok := ratelimit.FromContext(c.Request.Context()); ok { // Apply rate limiting after reading but before processing any events. batchProcessor = modelprocessor.Chained{ - rateLimitBatchProcessor(c.RateLimiter, batchSize), + rateLimitBatchProcessor(limiter, batchSize), batchProcessor, } } @@ -99,12 +95,7 @@ func Handler(handler StreamHandler, batchProcessor model.BatchProcessor) request return } - metadata := model.Metadata{ - UserAgent: model.UserAgent{Original: c.RequestMetadata.UserAgent}, - Client: model.Client{IP: c.RequestMetadata.ClientIP}, - System: model.System{IP: c.RequestMetadata.SystemIP}, - } - + metadata := requestMetadataFunc(c) var result stream.Result if err := handler.HandleStream( c.Request.Context(), @@ -132,7 +123,7 @@ func rateLimitBatch(ctx context.Context, limiter *rate.Limiter, batchSize int) e ctx, cancel := context.WithTimeout(ctx, rateLimitTimeout) defer cancel() if err := limiter.WaitN(ctx, batchSize); err != nil { - return errRateLimitExceeded + return ratelimit.ErrRateLimitExceeded } return nil } @@ -191,7 +182,7 @@ func writeStreamResult(c *request.Context, sr *stream.Result) { errID = request.IDResponseErrorsMethodNotAllowed case errors.Is(err, errInvalidContentType): errID = request.IDResponseErrorsValidate - case errors.Is(err, errRateLimitExceeded): + case errors.Is(err, ratelimit.ErrRateLimitExceeded): errID = request.IDResponseErrorsRateLimit } } diff --git a/beater/api/intake/handler_test.go b/beater/api/intake/handler_test.go index 5a4b7cf1b4d..ed4074520f4 100644 --- a/beater/api/intake/handler_test.go +++ b/beater/api/intake/handler_test.go @@ -34,6 +34,7 @@ import ( "golang.org/x/time/rate" "github.com/elastic/apm-server/approvaltest" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/headers" "github.com/elastic/apm-server/beater/request" @@ -136,7 +137,7 @@ func TestIntakeHandler(t *testing.T) { tc.setup(t) // call handler - h := Handler(tc.processor, tc.batchProcessor) + h := Handler(tc.processor, emptyRequestMetadata, tc.batchProcessor) h(tc.c) require.Equal(t, string(tc.id), string(tc.c.Result.ID)) @@ -190,12 +191,15 @@ func TestRateLimiting(t *testing.T) { var tc testcaseIntakeHandler tc.path = "ratelimit.ndjson" tc.setup(t) - tc.c.RateLimiter = test.limiter + + tc.c.Request = tc.c.Request.WithContext( + ratelimit.ContextWithLimiter(tc.c.Request.Context(), test.limiter), + ) if test.preconsumed > 0 { test.limiter.AllowN(time.Now(), test.preconsumed) } - h := Handler(tc.processor, tc.batchProcessor) + h := Handler(tc.processor, emptyRequestMetadata, tc.batchProcessor) h(tc.c) if test.expectLimited { @@ -275,3 +279,7 @@ func compressedRequest(t *testing.T, compressionType string, compressPayload boo req.Header.Set(headers.ContentEncoding, compressionType) return req } + +func emptyRequestMetadata(*request.Context) model.Metadata { + return model.Metadata{} +} diff --git a/beater/api/mux.go b/beater/api/mux.go index 55e9c1ee73f..436a96d515c 100644 --- a/beater/api/mux.go +++ b/beater/api/mux.go @@ -31,6 +31,7 @@ import ( "github.com/elastic/apm-server/beater/api/config/agent" "github.com/elastic/apm-server/beater/api/intake" "github.com/elastic/apm-server/beater/api/profile" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/api/root" "github.com/elastic/apm-server/beater/authorization" "github.com/elastic/apm-server/beater/config" @@ -68,6 +69,13 @@ const ( IntakeRUMV3Path = "/intake/v3/rum/events" ) +var ( + // rumAgents holds the current and previous agent names for the + // RUM JavaScript agent. This is used for restricting which config + // is supplied to anonymous agents. + rumAgents = []string{"rum-js", "js-base"} +) + // NewMux registers apm handlers to paths building up the APM Server API. func NewMux( beatInfo beat.Info, @@ -75,6 +83,7 @@ func NewMux( report publish.Reporter, batchProcessor model.BatchProcessor, fetcher agentcfg.Fetcher, + ratelimitStore *ratelimit.Store, ) (*http.ServeMux, error) { pool := request.NewContextPool() mux := http.NewServeMux() @@ -91,6 +100,7 @@ func NewMux( authBuilder: auth, reporter: report, batchProcessor: batchProcessor, + ratelimitStore: ratelimitStore, } type route struct { @@ -140,33 +150,46 @@ type routeBuilder struct { authBuilder *authorization.Builder reporter publish.Reporter batchProcessor model.BatchProcessor + ratelimitStore *ratelimit.Store } func (r *routeBuilder) profileHandler() (request.Handler, error) { - h := profile.Handler(r.batchProcessor) + requestMetadataFunc := emptyRequestMetadata + if r.cfg.AugmentEnabled { + requestMetadataFunc = backendRequestMetadata + } + h := profile.Handler(requestMetadataFunc, r.batchProcessor) authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeEventWrite.Action) - return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, profile.MonitoringMap)...) + return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, r.ratelimitStore, profile.MonitoringMap)...) } func (r *routeBuilder) backendIntakeHandler() (request.Handler, error) { - h := intake.Handler(stream.BackendProcessor(r.cfg), r.batchProcessor) + requestMetadataFunc := emptyRequestMetadata + if r.cfg.AugmentEnabled { + requestMetadataFunc = backendRequestMetadata + } + h := intake.Handler(stream.BackendProcessor(r.cfg), requestMetadataFunc, r.batchProcessor) authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeEventWrite.Action) - return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, intake.MonitoringMap)...) + return middleware.Wrap(h, backendMiddleware(r.cfg, authHandler, r.ratelimitStore, intake.MonitoringMap)...) } func (r *routeBuilder) rumIntakeHandler(newProcessor func(*config.Config) *stream.Processor) func() (request.Handler, error) { + requestMetadataFunc := emptyRequestMetadata + if r.cfg.AugmentEnabled { + requestMetadataFunc = rumRequestMetadata + } return func() (request.Handler, error) { batchProcessor := r.batchProcessor batchProcessor = batchProcessorWithAllowedServiceNames(batchProcessor, r.cfg.RumConfig.AllowServiceNames) - h := intake.Handler(newProcessor(r.cfg), batchProcessor) - return middleware.Wrap(h, rumMiddleware(r.cfg, nil, intake.MonitoringMap)...) + h := intake.Handler(newProcessor(r.cfg), requestMetadataFunc, batchProcessor) + return middleware.Wrap(h, rumMiddleware(r.cfg, nil, r.ratelimitStore, intake.MonitoringMap)...) } } func (r *routeBuilder) sourcemapHandler() (request.Handler, error) { h := sourcemap.Handler(r.reporter) authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeSourcemapWrite.Action) - return middleware.Wrap(h, sourcemapMiddleware(r.cfg, authHandler)...) + return middleware.Wrap(h, sourcemapMiddleware(r.cfg, authHandler, r.ratelimitStore)...) } func (r *routeBuilder) rootHandler() (request.Handler, error) { @@ -177,26 +200,27 @@ func (r *routeBuilder) rootHandler() (request.Handler, error) { func (r *routeBuilder) backendAgentConfigHandler(f agentcfg.Fetcher) func() (request.Handler, error) { return func() (request.Handler, error) { authHandler := r.authBuilder.ForPrivilege(authorization.PrivilegeAgentConfigRead.Action) - return agentConfigHandler(r.cfg, authHandler, backendMiddleware, f) + return agentConfigHandler(r.cfg, authHandler, r.ratelimitStore, backendMiddleware, f) } } func (r *routeBuilder) rumAgentConfigHandler(f agentcfg.Fetcher) func() (request.Handler, error) { return func() (request.Handler, error) { - return agentConfigHandler(r.cfg, nil, rumMiddleware, f) + return agentConfigHandler(r.cfg, nil, r.ratelimitStore, rumMiddleware, f) } } -type middlewareFunc func(*config.Config, *authorization.Handler, map[request.ResultID]*monitoring.Int) []middleware.Middleware +type middlewareFunc func(*config.Config, *authorization.Handler, *ratelimit.Store, map[request.ResultID]*monitoring.Int) []middleware.Middleware func agentConfigHandler( cfg *config.Config, authHandler *authorization.Handler, + ratelimitStore *ratelimit.Store, middlewareFunc middlewareFunc, f agentcfg.Fetcher, ) (request.Handler, error) { - mw := middlewareFunc(cfg, authHandler, agent.MonitoringMap) - h := agent.NewHandler(f, cfg.KibanaAgentConfig, cfg.DefaultServiceEnvironment) + mw := middlewareFunc(cfg, authHandler, ratelimitStore, agent.MonitoringMap) + h := agent.NewHandler(f, cfg.KibanaAgentConfig, cfg.DefaultServiceEnvironment, rumAgents) if !cfg.Kibana.Enabled && cfg.AgentConfigs == nil { msg := "Agent remote configuration is disabled. " + @@ -219,37 +243,30 @@ func apmMiddleware(m map[request.ResultID]*monitoring.Int) []middleware.Middlewa } } -func backendMiddleware(cfg *config.Config, auth *authorization.Handler, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { +func backendMiddleware(cfg *config.Config, auth *authorization.Handler, ratelimitStore *ratelimit.Store, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { backendMiddleware := append(apmMiddleware(m), middleware.ResponseHeadersMiddleware(cfg.ResponseHeaders), middleware.AuthorizationMiddleware(auth, true), + middleware.AnonymousRateLimitMiddleware(ratelimitStore), ) - if cfg.AugmentEnabled { - backendMiddleware = append(backendMiddleware, middleware.SystemMetadataMiddleware()) - } return backendMiddleware } -func rumMiddleware(cfg *config.Config, _ *authorization.Handler, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { +func rumMiddleware(cfg *config.Config, auth *authorization.Handler, ratelimitStore *ratelimit.Store, m map[request.ResultID]*monitoring.Int) []middleware.Middleware { msg := "RUM endpoint is disabled. " + "Configure the `apm-server.rum` section in apm-server.yml to enable ingestion of RUM events. " + "If you are not using the RUM agent, you can safely ignore this error." rumMiddleware := append(apmMiddleware(m), middleware.ResponseHeadersMiddleware(cfg.ResponseHeaders), middleware.ResponseHeadersMiddleware(cfg.RumConfig.ResponseHeaders), - middleware.SetRumFlagMiddleware(), - middleware.SetIPRateLimitMiddleware(cfg.RumConfig.EventRate), middleware.CORSMiddleware(cfg.RumConfig.AllowOrigins, cfg.RumConfig.AllowHeaders), middleware.AnonymousAuthorizationMiddleware(), - middleware.KillSwitchMiddleware(cfg.RumConfig.Enabled, msg), + middleware.AnonymousRateLimitMiddleware(ratelimitStore), ) - if cfg.AugmentEnabled { - rumMiddleware = append(rumMiddleware, middleware.UserMetadataMiddleware()) - } - return rumMiddleware + return append(rumMiddleware, middleware.KillSwitchMiddleware(cfg.RumConfig.Enabled, msg)) } -func sourcemapMiddleware(cfg *config.Config, auth *authorization.Handler) []middleware.Middleware { +func sourcemapMiddleware(cfg *config.Config, auth *authorization.Handler, ratelimitStore *ratelimit.Store) []middleware.Middleware { msg := "Sourcemap upload endpoint is disabled. " + "Configure the `apm-server.rum` section in apm-server.yml to enable sourcemap uploads. " + "If you are not using the RUM agent, you can safely ignore this error." @@ -257,8 +274,8 @@ func sourcemapMiddleware(cfg *config.Config, auth *authorization.Handler) []midd msg = "When APM Server is managed by Fleet, Sourcemaps must be uploaded directly to Elasticsearch." } enabled := cfg.RumConfig.Enabled && cfg.RumConfig.SourceMapping.Enabled && !cfg.DataStreams.Enabled - return append(backendMiddleware(cfg, auth, sourcemap.MonitoringMap), - middleware.KillSwitchMiddleware(enabled, msg)) + backendMiddleware := backendMiddleware(cfg, auth, ratelimitStore, sourcemap.MonitoringMap) + return append(backendMiddleware, middleware.KillSwitchMiddleware(enabled, msg)) } func rootMiddleware(cfg *config.Config, auth *authorization.Handler) []middleware.Middleware { @@ -288,3 +305,18 @@ func batchProcessorWithAllowedServiceNames(p model.BatchProcessor, allowedServic } return modelprocessor.Chained{restrictServiceName, p} } + +func emptyRequestMetadata(c *request.Context) model.Metadata { + return model.Metadata{} +} + +func backendRequestMetadata(c *request.Context) model.Metadata { + return model.Metadata{System: model.System{IP: c.SourceIP}} +} + +func rumRequestMetadata(c *request.Context) model.Metadata { + return model.Metadata{ + Client: model.Client{IP: c.SourceIP}, + UserAgent: model.UserAgent{Original: c.UserAgent}, + } +} diff --git a/beater/api/mux_intake_rum_test.go b/beater/api/mux_intake_rum_test.go index 5bf99c1ee44..2f5f728a5b2 100644 --- a/beater/api/mux_intake_rum_test.go +++ b/beater/api/mux_intake_rum_test.go @@ -29,6 +29,7 @@ import ( "github.com/elastic/apm-server/approvaltest" "github.com/elastic/apm-server/beater/api/intake" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/headers" "github.com/elastic/apm-server/beater/middleware" @@ -36,6 +37,7 @@ import ( ) func TestOPTIONS(t *testing.T) { + ratelimitStore, _ := ratelimit.NewStore(1, 1, 1) requestTaken := make(chan struct{}, 1) done := make(chan struct{}, 1) @@ -46,7 +48,7 @@ func TestOPTIONS(t *testing.T) { requestTaken <- struct{}{} <-done }, - rumMiddleware(cfg, nil, intake.MonitoringMap)...) + rumMiddleware(cfg, nil, ratelimitStore, intake.MonitoringMap)...) // use this to block the single allowed concurrent requests go func() { diff --git a/beater/api/mux_test.go b/beater/api/mux_test.go index 8ce997c1eda..7b6be8969e8 100644 --- a/beater/api/mux_test.go +++ b/beater/api/mux_test.go @@ -28,6 +28,7 @@ import ( "github.com/elastic/apm-server/agentcfg" "github.com/elastic/apm-server/approvaltest" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/beatertest" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/request" @@ -76,7 +77,8 @@ func requestToMuxerWithHeaderAndQueryString( func requestToMuxer(cfg *config.Config, r *http.Request) (*httptest.ResponseRecorder, error) { nopReporter := func(context.Context, publish.PendingReq) error { return nil } nopBatchProcessor := model.ProcessBatchFunc(func(context.Context, *model.Batch) error { return nil }) - mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg)) + ratelimitStore, _ := ratelimit.NewStore(1000, 1000, 1000) + mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg), ratelimitStore) if err != nil { return nil, err } @@ -111,7 +113,8 @@ func testMonitoringMiddleware(t *testing.T, urlPath string, monitoringMap map[re func newTestMux(t *testing.T, cfg *config.Config) http.Handler { nopReporter := func(context.Context, publish.PendingReq) error { return nil } nopBatchProcessor := model.ProcessBatchFunc(func(context.Context, *model.Batch) error { return nil }) - mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg)) + ratelimitStore, _ := ratelimit.NewStore(1000, 1000, 1000) + mux, err := NewMux(beat.Info{Version: "1.2.3"}, cfg, nopReporter, nopBatchProcessor, agentcfg.NewFetcher(cfg), ratelimitStore) require.NoError(t, err) return mux } diff --git a/beater/api/profile/handler.go b/beater/api/profile/handler.go index bf1c10f7617..e1f30da55fd 100644 --- a/beater/api/profile/handler.go +++ b/beater/api/profile/handler.go @@ -55,8 +55,13 @@ const ( profileContentLengthLimit = 10 * 1024 * 1024 ) +// RequestMetadataFunc is a function type supplied to Handler for extracting +// metadata from the request. This is used for conditionally injecting the +// source IP address as `client.ip` for RUM. +type RequestMetadataFunc func(*request.Context) model.Metadata + // Handler returns a request.Handler for managing profile requests. -func Handler(processor model.BatchProcessor) request.Handler { +func Handler(requestMetadataFunc RequestMetadataFunc, processor model.BatchProcessor) request.Handler { handle := func(c *request.Context) (*result, error) { if c.Request.Method != http.MethodPost { return nil, requestError{ @@ -71,14 +76,6 @@ func Handler(processor model.BatchProcessor) request.Handler { } } - ok := c.RateLimiter == nil || c.RateLimiter.Allow() - if !ok { - return nil, requestError{ - id: request.IDResponseErrorsRateLimit, - err: errors.New("rate limit exceeded"), - } - } - var totalLimitRemaining int64 = profileContentLengthLimit var profiles []*pprof_profile.Profile var profileMetadata model.Metadata @@ -104,10 +101,7 @@ func Handler(processor model.BatchProcessor) request.Handler { } r := &decoder.LimitedReader{R: part, N: metadataContentLengthLimit} dec := decoder.NewJSONDecoder(r) - metadata := model.Metadata{ - UserAgent: model.UserAgent{Original: c.RequestMetadata.UserAgent}, - Client: model.Client{IP: c.RequestMetadata.ClientIP}, - System: model.System{IP: c.RequestMetadata.SystemIP}} + metadata := requestMetadataFunc(c) if err := v2.DecodeMetadata(dec, &metadata); err != nil { if r.N < 0 { return nil, requestError{ diff --git a/beater/api/profile/handler_test.go b/beater/api/profile/handler_test.go index 9a388eea4b3..0fa0f80eb52 100644 --- a/beater/api/profile/handler_test.go +++ b/beater/api/profile/handler_test.go @@ -31,7 +31,6 @@ import ( "strings" "testing" - "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/model" "github.com/stretchr/testify/assert" @@ -45,8 +44,6 @@ import ( const pprofContentType = `application/x-protobuf; messageType="perftools.profiles.Profile"` func TestHandler(t *testing.T) { - var rateLimit, err = ratelimit.NewStore(1, 0, 0) - require.NoError(t, err) for name, tc := range map[string]testcaseIntakeHandler{ "MethodNotAllowed": { r: httptest.NewRequest(http.MethodGet, "/", nil), @@ -60,10 +57,6 @@ func TestHandler(t *testing.T) { }(), id: request.IDResponseErrorsValidate, }, - "RateLimitExceeded": { - rateLimit: rateLimit, - id: request.IDResponseErrorsRateLimit, - }, "Closing": { batchProcessor: func(t *testing.T) model.BatchProcessor { return model.ProcessBatchFunc(func(context.Context, *model.Batch) error { @@ -199,10 +192,7 @@ func TestHandler(t *testing.T) { } { t.Run(name, func(t *testing.T) { tc.setup(t) - if tc.rateLimit != nil { - tc.c.RateLimiter = tc.rateLimit.ForIP(&http.Request{}) - } - Handler(tc.batchProcessor(t))(tc.c) + Handler(emptyRequestMetadata, tc.batchProcessor(t))(tc.c) assert.Equal(t, string(tc.id), string(tc.c.Result.ID)) resultStatus := request.MapResultIDToStatus[tc.id] @@ -225,7 +215,6 @@ type testcaseIntakeHandler struct { c *request.Context w *httptest.ResponseRecorder r *http.Request - rateLimit *ratelimit.Store batchProcessor func(t *testing.T) model.BatchProcessor reports int parts []part @@ -299,3 +288,7 @@ func prettyJSON(v interface{}) string { enc.Encode(v) return buf.String() } + +func emptyRequestMetadata(*request.Context) model.Metadata { + return model.Metadata{} +} diff --git a/beater/middleware/rum_middleware.go b/beater/api/ratelimit/context.go similarity index 52% rename from beater/middleware/rum_middleware.go rename to beater/api/ratelimit/context.go index be3f8f33d4d..d17f99597de 100644 --- a/beater/middleware/rum_middleware.go +++ b/beater/api/ratelimit/context.go @@ -15,18 +15,28 @@ // specific language governing permissions and limitations // under the License. -package middleware +package ratelimit import ( - "github.com/elastic/apm-server/beater/request" + "context" + + "github.com/pkg/errors" + "golang.org/x/time/rate" ) -// SetRumFlagMiddleware sets a rum flag in the context -func SetRumFlagMiddleware() Middleware { - return func(h request.Handler) (request.Handler, error) { - return func(c *request.Context) { - c.IsRum = true - h(c) - }, nil - } +// ErrRateLimitExceeded is returned when the rate limit is exceeded. +var ErrRateLimitExceeded = errors.New("rate limit exceeded") + +type rateLimiterKey struct{} + +// FromContext returns a rate.Limiter if one is contained in ctx, +// and a bool indicating whether one was found. +func FromContext(ctx context.Context) (*rate.Limiter, bool) { + limiter, ok := ctx.Value(rateLimiterKey{}).(*rate.Limiter) + return limiter, ok +} + +// ContextWithLimiter returns a copy of parent associated with limiter. +func ContextWithLimiter(parent context.Context, limiter *rate.Limiter) context.Context { + return context.WithValue(parent, rateLimiterKey{}, limiter) } diff --git a/beater/api/ratelimit/store.go b/beater/api/ratelimit/store.go index b9fac6c8cfd..660a3af76cd 100644 --- a/beater/api/ratelimit/store.go +++ b/beater/api/ratelimit/store.go @@ -18,11 +18,9 @@ package ratelimit import ( - "net/http" + "net" "sync" - "github.com/elastic/apm-server/utility" - "github.com/hashicorp/golang-lru/simplelru" "github.com/pkg/errors" "golang.org/x/time/rate" @@ -63,8 +61,9 @@ func NewStore(size, rateLimit, burstFactor int) (*Store, error) { return &store, nil } -// acquire returns a rate.Limiter instance for the given key -func (s *Store) acquire(key string) *rate.Limiter { +// ForIP returns a rate limiter for the given IP. +func (s *Store) ForIP(ip net.IP) *rate.Limiter { + key := ip.String() // lock get and add action for cache to allow proper eviction handling without // race conditions. @@ -83,11 +82,3 @@ func (s *Store) acquire(key string) *rate.Limiter { } return limiter } - -// ForIP returns a rate limiter for the given request IP -func (s *Store) ForIP(r *http.Request) *rate.Limiter { - if s == nil { - return nil - } - return s.acquire(utility.RemoteAddr(r)) -} diff --git a/beater/api/ratelimit/store_test.go b/beater/api/ratelimit/store_test.go index e6559e089f0..4706e9dbff8 100644 --- a/beater/api/ratelimit/store_test.go +++ b/beater/api/ratelimit/store_test.go @@ -18,7 +18,7 @@ package ratelimit import ( - "net/http" + "net" "testing" "time" @@ -38,7 +38,6 @@ func TestCacheInitFails(t *testing.T) { c, err := NewStore(test.size, test.limit, 3) assert.Error(t, err) assert.Nil(t, c) - assert.Nil(t, c.ForIP(&http.Request{})) } } @@ -50,20 +49,20 @@ func TestCacheEviction(t *testing.T) { require.NoError(t, err) // add new limiter - rlA := store.acquire("a") + rlA := store.ForIP(net.ParseIP("127.0.0.1")) rlA.AllowN(time.Now(), 3) // add new limiter - rlB := store.acquire("b") + rlB := store.ForIP(net.ParseIP("127.0.0.2")) rlB.AllowN(time.Now(), 2) // reuse evicted limiter rlA - rlC := store.acquire("c") + rlC := store.ForIP(net.ParseIP("127.0.0.3")) assert.False(t, rlC.Allow()) assert.Equal(t, rlC, store.evictedLimiter) // reuse evicted limiter rlB - rlD := store.acquire("a") + rlD := store.ForIP(net.ParseIP("127.0.0.1")) assert.True(t, rlD.Allow()) assert.False(t, rlD.Allow()) assert.Equal(t, rlD, store.evictedLimiter) @@ -77,22 +76,6 @@ func TestCacheEviction(t *testing.T) { func TestCacheOk(t *testing.T) { store, err := NewStore(1, 1, 1) require.NoError(t, err) - limiter := store.acquire("a") + limiter := store.ForIP(net.ParseIP("127.0.0.1")) assert.NotNil(t, limiter) } - -func TestRateLimitPerIP(t *testing.T) { - store, err := NewStore(2, 1, 1) - require.NoError(t, err) - - var reqFrom = func(ip string) *http.Request { - r := http.Request{} - r.Header = http.Header{} - r.Header.Set("X-Real-Ip", ip) - return &r - } - assert.True(t, store.ForIP(reqFrom("10.10.10.1")).Allow()) - assert.False(t, store.ForIP(reqFrom("10.10.10.1")).Allow()) - assert.True(t, store.ForIP(reqFrom("10.10.10.2")).Allow()) - assert.False(t, store.ForIP(reqFrom("10.10.10.3")).Allow()) -} diff --git a/beater/http.go b/beater/http.go index 994e4b396ed..d7bcb77bab5 100644 --- a/beater/http.go +++ b/beater/http.go @@ -29,6 +29,7 @@ import ( "github.com/elastic/apm-server/agentcfg" "github.com/elastic/apm-server/beater/api" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/model" "github.com/elastic/apm-server/model/modelprocessor" @@ -47,7 +48,16 @@ type httpServer struct { grpcListener net.Listener } -func newHTTPServer(logger *logp.Logger, info beat.Info, cfg *config.Config, tracer *apm.Tracer, reporter publish.Reporter, batchProcessor model.BatchProcessor, f agentcfg.Fetcher) (*httpServer, error) { +func newHTTPServer( + logger *logp.Logger, + info beat.Info, + cfg *config.Config, + tracer *apm.Tracer, + reporter publish.Reporter, + batchProcessor model.BatchProcessor, + agentcfgFetcher agentcfg.Fetcher, + ratelimitStore *ratelimit.Store, +) (*httpServer, error) { // Add a model processor that checks authorization for the agent and service for each event. batchProcessor = modelprocessor.Chained{ @@ -55,7 +65,7 @@ func newHTTPServer(logger *logp.Logger, info beat.Info, cfg *config.Config, trac batchProcessor, } - mux, err := api.NewMux(info, cfg, reporter, batchProcessor, f) + mux, err := api.NewMux(info, cfg, reporter, batchProcessor, agentcfgFetcher, ratelimitStore) if err != nil { return nil, err } diff --git a/beater/middleware/authorization_middleware.go b/beater/middleware/authorization_middleware.go index 2de8adb25f1..724b87881d4 100644 --- a/beater/middleware/authorization_middleware.go +++ b/beater/middleware/authorization_middleware.go @@ -65,6 +65,7 @@ func AnonymousAuthorizationMiddleware() Middleware { return func(h request.Handler) (request.Handler, error) { return func(c *request.Context) { auth := authorization.AnonymousAuth{} + c.AuthResult = authorization.Result{Authorized: true, Anonymous: true} c.Request = c.Request.WithContext(authorization.ContextWithAuthorization(c.Request.Context(), auth)) h(c) }, nil diff --git a/beater/middleware/rate_limit_middleware.go b/beater/middleware/rate_limit_middleware.go index 8b71f65cfe6..cb954432517 100644 --- a/beater/middleware/rate_limit_middleware.go +++ b/beater/middleware/rate_limit_middleware.go @@ -19,20 +19,33 @@ package middleware import ( "github.com/elastic/apm-server/beater/api/ratelimit" - "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/request" ) -const burstMultiplier = 3 - -// SetIPRateLimitMiddleware sets a rate limiter -func SetIPRateLimitMiddleware(cfg config.EventRate) Middleware { - store, err := ratelimit.NewStore(cfg.LruSize, cfg.Limit, burstMultiplier) - +// AnonymousRateLimitMiddleware adds a rate.Limiter to the context of anonymous +// requests, first ensuring the client is allowed to perform a single event and +// responding with 429 Too Many Requests if it is not. +// +// This middleware must be wrapped by AuthorizationMiddleware, as it depends on +// the value of c.AuthResult.Anonymous. +func AnonymousRateLimitMiddleware(store *ratelimit.Store) Middleware { return func(h request.Handler) (request.Handler, error) { return func(c *request.Context) { - c.RateLimiter = store.ForIP(c.Request) + if c.AuthResult.Anonymous { + limiter := store.ForIP(c.SourceIP) + if !limiter.Allow() { + c.Result.SetWithError( + request.IDResponseErrorsRateLimit, + ratelimit.ErrRateLimitExceeded, + ) + c.Write() + return + } + ctx := c.Request.Context() + ctx = ratelimit.ContextWithLimiter(ctx, limiter) + c.Request = c.Request.WithContext(ctx) + } h(c) - }, err + }, nil } } diff --git a/beater/middleware/rate_limit_middleware_test.go b/beater/middleware/rate_limit_middleware_test.go new file mode 100644 index 00000000000..bd6c9d18b89 --- /dev/null +++ b/beater/middleware/rate_limit_middleware_test.go @@ -0,0 +1,107 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/apm-server/beater/api/ratelimit" + "github.com/elastic/apm-server/beater/request" +) + +func TestAnonymousRateLimitMiddleware(t *testing.T) { + type test struct { + burst int + anonymous bool + + expectStatusCode int + expectAllow bool + } + for _, test := range []test{{ + burst: 0, + anonymous: false, + expectStatusCode: http.StatusOK, + }, { + burst: 0, + anonymous: true, + expectStatusCode: http.StatusTooManyRequests, + }, { + burst: 1, + anonymous: true, + expectStatusCode: http.StatusOK, + expectAllow: false, + }, { + burst: 2, + anonymous: true, + expectStatusCode: http.StatusOK, + expectAllow: true, + }} { + store, _ := ratelimit.NewStore(1, 1, test.burst) + middleware := AnonymousRateLimitMiddleware(store) + handler := func(c *request.Context) { + limiter, ok := ratelimit.FromContext(c.Request.Context()) + if test.anonymous { + require.True(t, ok) + assert.Equal(t, test.expectAllow, limiter.Allow()) + } else { + require.False(t, ok) + } + } + wrapped, err := middleware(handler) + require.NoError(t, err) + + c := request.NewContext() + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + c.Reset(w, r) + c.AuthResult.Anonymous = test.anonymous + + wrapped(c) + assert.Equal(t, test.expectStatusCode, w.Code) + } +} + +func TestAnonymousRateLimitMiddlewareForIP(t *testing.T) { + store, _ := ratelimit.NewStore(2, 1, 1) + middleware := AnonymousRateLimitMiddleware(store) + handler := func(c *request.Context) {} + wrapped, err := middleware(handler) + require.NoError(t, err) + + requestWithIP := func(ip string) int { + c := request.NewContext() + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.RemoteAddr = ip + c.Reset(w, r) + c.AuthResult.Anonymous = true + wrapped(c) + return w.Code + } + assert.Equal(t, http.StatusOK, requestWithIP("10.1.1.1")) + assert.Equal(t, http.StatusTooManyRequests, requestWithIP("10.1.1.1")) + assert.Equal(t, http.StatusOK, requestWithIP("10.1.1.2")) + + // ratelimit.Store size is 2: the 3rd IP reuses an existing (depleted) rate limiter. + assert.Equal(t, http.StatusTooManyRequests, requestWithIP("10.1.1.3")) +} diff --git a/beater/middleware/request_metadata_middleware.go b/beater/middleware/request_metadata_middleware.go deleted file mode 100644 index 9bb57d44808..00000000000 --- a/beater/middleware/request_metadata_middleware.go +++ /dev/null @@ -1,46 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package middleware - -import ( - "github.com/elastic/apm-server/beater/request" - "github.com/elastic/apm-server/utility" -) - -// UserMetadataMiddleware returns a Middleware recording request-level -// user metadata (e.g. user-agent and source IP) in the request's context. -func UserMetadataMiddleware() Middleware { - return func(h request.Handler) (request.Handler, error) { - return func(c *request.Context) { - c.RequestMetadata.UserAgent = utility.UserAgentHeader(c.Request.Header) - c.RequestMetadata.ClientIP = utility.ExtractIP(c.Request) - h(c) - }, nil - } -} - -// SystemMetadataMiddleware returns a Middleware recording request-level -// system metadata (e.g. source IP) in the request's context. -func SystemMetadataMiddleware() Middleware { - return func(h request.Handler) (request.Handler, error) { - return func(c *request.Context) { - c.RequestMetadata.SystemIP = utility.ExtractIP(c.Request) - h(c) - }, nil - } -} diff --git a/beater/middleware/request_metadata_middleware_test.go b/beater/middleware/request_metadata_middleware_test.go deleted file mode 100644 index ac69123c5c3..00000000000 --- a/beater/middleware/request_metadata_middleware_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package middleware - -import ( - "fmt" - "net" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/elastic/apm-server/beater/beatertest" -) - -func TestUserMetadataMiddleware(t *testing.T) { - type test struct { - remoteAddr string - userAgent []string - expectedIP net.IP - expectedUserAgent string - } - - ua1 := "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36" - ua2 := "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.14; rv:67.0) Gecko/20100101 Firefox/67.0" - tests := []test{ - {remoteAddr: "1.2.3.4:1234", expectedIP: net.ParseIP("1.2.3.4"), userAgent: []string{ua1, ua2}, expectedUserAgent: fmt.Sprintf("%s, %s", ua1, ua2)}, - {remoteAddr: "not-an-ip:1234", userAgent: []string{ua1}, expectedUserAgent: ua1}, - {remoteAddr: ""}, - } - - for _, test := range tests { - c, _ := beatertest.DefaultContextWithResponseRecorder() - c.Request.RemoteAddr = test.remoteAddr - for _, ua := range test.userAgent { - c.Request.Header.Add("User-Agent", ua) - } - - Apply(UserMetadataMiddleware(), beatertest.HandlerIdle)(c) - assert.Equal(t, test.expectedUserAgent, c.RequestMetadata.UserAgent) - assert.Equal(t, test.expectedIP, c.RequestMetadata.ClientIP) - } -} - -func TestSystemMetadataMiddleware(t *testing.T) { - type test struct { - remoteAddr string - expectedIP net.IP - } - tests := []test{ - {remoteAddr: "1.2.3.4:1234", expectedIP: net.ParseIP("1.2.3.4")}, - {remoteAddr: "not-an-ip:1234"}, - {remoteAddr: ""}, - } - - for _, test := range tests { - c, _ := beatertest.DefaultContextWithResponseRecorder() - c.Request.RemoteAddr = test.remoteAddr - - Apply(SystemMetadataMiddleware(), beatertest.HandlerIdle)(c) - assert.Equal(t, test.expectedIP, c.RequestMetadata.SystemIP) - } -} diff --git a/beater/request/context.go b/beater/request/context.go index 118643ef412..4478ce9a8e3 100644 --- a/beater/request/context.go +++ b/beater/request/context.go @@ -23,13 +23,12 @@ import ( "net/http" "strings" - "golang.org/x/time/rate" - "github.com/elastic/beats/v7/libbeat/logp" "github.com/elastic/apm-server/beater/authorization" "github.com/elastic/apm-server/beater/headers" logs "github.com/elastic/apm-server/log" + "github.com/elastic/apm-server/utility" ) const ( @@ -43,26 +42,17 @@ var ( // Context abstracts request and response information for http requests type Context struct { - Request *http.Request - Logger *logp.Logger - RateLimiter *rate.Limiter - AuthResult authorization.Result - IsRum bool - Result Result - RequestMetadata Metadata + Request *http.Request + Logger *logp.Logger + AuthResult authorization.Result + Result Result + SourceIP net.IP + UserAgent string w http.ResponseWriter writeAttempts int } -// Metadata contains metadata extracted from the request by middleware, -// and should be merged into the event metadata. -type Metadata struct { - ClientIP net.IP - SystemIP net.IP - UserAgent string -} - // NewContext creates an empty Context struct func NewContext() *Context { return &Context{} @@ -72,23 +62,15 @@ func NewContext() *Context { func (c *Context) Reset(w http.ResponseWriter, r *http.Request) { c.Request = r c.Logger = nil - c.RateLimiter = nil c.AuthResult = authorization.Result{} - c.IsRum = false c.Result.Reset() - c.RequestMetadata.Reset() + c.SourceIP = utility.ExtractIP(r) + c.UserAgent = utility.UserAgentHeader(r.Header) c.w = w c.writeAttempts = 0 } -// Reset sets all attribtues of the Metadata instance to it's zero value -func (m *Metadata) Reset() { - m.ClientIP = nil - m.SystemIP = nil - m.UserAgent = "" -} - // Header returns the http.Header of the context's writer func (c *Context) Header() http.Header { return c.w.Header() diff --git a/beater/request/context_test.go b/beater/request/context_test.go index aa916ac5158..2ec5eb9e1ed 100644 --- a/beater/request/context_test.go +++ b/beater/request/context_test.go @@ -18,6 +18,7 @@ package request import ( + "net" "net/http" "net/http/httptest" "reflect" @@ -37,7 +38,11 @@ func TestContext_Reset(t *testing.T) { w1.WriteHeader(http.StatusServiceUnavailable) w2 := httptest.NewRecorder() r1 := httptest.NewRequest(http.MethodGet, "/", nil) + r1.RemoteAddr = "10.1.2.3:4321" + r1.Header.Set("User-Agent", "ua1") r2 := httptest.NewRequest(http.MethodHead, "/new", nil) + r2.RemoteAddr = "10.1.2.3:1234" + r2.Header.Set("User-Agent", "ua2") c := Context{ Request: r1, w: w1, @@ -65,8 +70,10 @@ func TestContext_Reset(t *testing.T) { assert.Equal(t, 0, c.writeAttempts) case "Result": assertResultIsEmpty(t, cVal.Field(i).Interface().(Result)) - case "RequestMetadata": - assert.Equal(t, Metadata{}, cVal.Field(i).Interface().(Metadata)) + case "SourceIP": + assert.Equal(t, net.ParseIP("10.1.2.3"), cVal.Field(i).Interface()) + case "UserAgent": + assert.Equal(t, "ua2", cVal.Field(i).Interface()) default: assert.Empty(t, cVal.Field(i).Interface(), cType.Field(i).Name) } diff --git a/beater/server.go b/beater/server.go index 7d144321402..11179ebb31b 100644 --- a/beater/server.go +++ b/beater/server.go @@ -32,6 +32,7 @@ import ( "github.com/elastic/beats/v7/libbeat/version" "github.com/elastic/apm-server/agentcfg" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/authorization" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/beater/interceptors" @@ -116,16 +117,24 @@ func newServer( reporter publish.Reporter, batchProcessor model.BatchProcessor, ) (server, error) { - fetcher := agentcfg.NewFetcher(cfg) - httpServer, err := newHTTPServer(logger, info, cfg, tracer, reporter, batchProcessor, fetcher) + agentcfgFetcher := agentcfg.NewFetcher(cfg) + ratelimitStore, err := ratelimit.NewStore( + cfg.RumConfig.EventRate.LruSize, + cfg.RumConfig.EventRate.Limit, + 3, // burst multiplier + ) if err != nil { return server{}, err } - grpcServer, err := newGRPCServer(logger, cfg, tracer, batchProcessor, httpServer.TLSConfig, fetcher) + httpServer, err := newHTTPServer(logger, info, cfg, tracer, reporter, batchProcessor, agentcfgFetcher, ratelimitStore) if err != nil { return server{}, err } - jaegerServer, err := jaeger.NewServer(logger, cfg, tracer, batchProcessor, fetcher) + grpcServer, err := newGRPCServer(logger, cfg, tracer, batchProcessor, httpServer.TLSConfig, agentcfgFetcher, ratelimitStore) + if err != nil { + return server{}, err + } + jaegerServer, err := jaeger.NewServer(logger, cfg, tracer, batchProcessor, agentcfgFetcher) if err != nil { return server{}, err } @@ -144,7 +153,8 @@ func newGRPCServer( tracer *apm.Tracer, batchProcessor model.BatchProcessor, tlsConfig *tls.Config, - fetcher agentcfg.Fetcher, + agentcfgFetcher agentcfg.Fetcher, + ratelimitStore *ratelimit.Store, ) (*grpc.Server, error) { // TODO(axw) share auth builder with beater/api. authBuilder, err := authorization.NewBuilder(cfg.AgentAuth) @@ -152,10 +162,11 @@ func newGRPCServer( return nil, err } - // NOTE(axw) even if TLS is enabled we should not use grpc.Creds, as TLS is handled by the net/http server. apmInterceptor := apmgrpc.NewUnaryServerInterceptor(apmgrpc.WithRecovery(), apmgrpc.WithTracer(tracer)) authInterceptor := newAuthUnaryServerInterceptor(authBuilder) + // Note that we intentionally do not use a grpc.Creds ServerOption + // even if TLS is enabled, as TLS is handled by the net/http server. logger = logger.Named("grpc") srv := grpc.NewServer( grpc.ChainUnaryInterceptor( @@ -165,6 +176,9 @@ func newGRPCServer( interceptors.Metrics(logger, otlp.RegistryMonitoringMaps, jaeger.RegistryMonitoringMaps), interceptors.Timeout(), authInterceptor, + + // TODO(axw) add a rate limiting interceptor here once we've + // updated authInterceptor to handle auth for Jaeger requests. ), ) @@ -182,7 +196,7 @@ func newGRPCServer( batchProcessor, } - jaeger.RegisterGRPCServices(srv, authBuilder, jaeger.ElasticAuthTag, logger, batchProcessor, fetcher) + jaeger.RegisterGRPCServices(srv, authBuilder, jaeger.ElasticAuthTag, logger, batchProcessor, agentcfgFetcher) if err := otlp.RegisterGRPCServices(srv, batchProcessor); err != nil { return nil, err } diff --git a/beater/tracing.go b/beater/tracing.go index 8d37f979ea5..6a9f791795b 100644 --- a/beater/tracing.go +++ b/beater/tracing.go @@ -27,6 +27,7 @@ import ( "github.com/elastic/apm-server/agentcfg" "github.com/elastic/apm-server/beater/api" + "github.com/elastic/apm-server/beater/api/ratelimit" "github.com/elastic/apm-server/beater/config" "github.com/elastic/apm-server/model" "github.com/elastic/apm-server/publish" @@ -59,7 +60,18 @@ func newTracerServer(listener net.Listener, logger *logp.Logger) (*tracerServer, } }) cfg := config.DefaultConfig() - mux, err := api.NewMux(beat.Info{}, cfg, nopReporter, processBatch, agentcfg.NewFetcher(cfg)) + ratelimitStore, err := ratelimit.NewStore(1, 1, 1) // unused, arbitrary params + if err != nil { + return nil, err + } + mux, err := api.NewMux( + beat.Info{}, + cfg, + nopReporter, + processBatch, + agentcfg.NewFetcher(cfg), + ratelimitStore, + ) if err != nil { return nil, err } diff --git a/systemtest/rum_test.go b/systemtest/rum_test.go index 7f72a3ea6d2..4d82b0a71ef 100644 --- a/systemtest/rum_test.go +++ b/systemtest/rum_test.go @@ -192,7 +192,12 @@ func TestRUMRateLimit(t *testing.T) { g.Go(func() error { return sendEvents("10.11.12.13", srv.Config.RUM.RateLimit.EventLimit) }) g.Go(func() error { return sendEvents("10.11.12.14", srv.Config.RUM.RateLimit.EventLimit) }) g.Go(func() error { return sendEvents("10.11.12.15", srv.Config.RUM.RateLimit.EventLimit) }) - assert.EqualError(t, g.Wait(), `429 Too Many Requests ({"accepted":0,"errors":[{"message":"rate limit exceeded"}]})`) + err = g.Wait() + require.Error(t, err) + + // The exact error differs, depending on whether rate limiting was applied at the request + // level, or at the event stream level. Either could occur. + assert.Regexp(t, `429 Too Many Requests .*`, err.Error()) } func sendRUMEventsPayload(t *testing.T, srv *apmservertest.Server, payloadFile string) {