From eed4c2a2a86620825e698b96556bd0bfc7d3d312 Mon Sep 17 00:00:00 2001 From: Chris Berkhout Date: Thu, 5 Dec 2024 09:58:22 +0100 Subject: [PATCH] x-pack/filebeat/input/entityanalytics/provider/okta: Rate limiting fixes (#41583) - Separate rate limits by endpoint. - Stop requests until reset when `x-rate-limit-remaining: 0`. (cherry picked from commit 4e19d09ab2e7c908f16b0d9bed1eee994ee8c744) --- CHANGELOG.next.asciidoc | 1 + .../provider/okta/internal/okta/okta.go | 148 +++++++----------- .../provider/okta/internal/okta/okta_test.go | 37 +++-- .../okta/internal/okta/ratelimiter.go | 97 ++++++++++++ .../okta/internal/okta/ratelimiter_test.go | 84 ++++++++++ .../entityanalytics/provider/okta/okta.go | 5 +- .../provider/okta/okta_test.go | 4 +- 7 files changed, 264 insertions(+), 112 deletions(-) create mode 100644 x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter.go create mode 100644 x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter_test.go diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index a37132bc6de..264097af201 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -183,6 +183,7 @@ https://github.com/elastic/beats/compare/v8.8.1\...main[Check the HEAD diff] - Fix the "No such input type exist: 'salesforce'" error on the Windows/AIX platform. {pull}41664[41664] - Improve S3 object size metric calculation to support situations where Content-Length is not available. {pull}41755[41755] - Fix handling of http_endpoint request exceeding memory limits. {issue}41764[41764] {pull}41765[41765] +- Rate limiting fixes in the Okta provider of the Entity Analytics input. 
{issue}40106[40106] {pull}41583[41583] *Heartbeat* diff --git a/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta.go b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta.go index 3d8bdae11c9..42ffc060178 100644 --- a/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta.go +++ b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta.go @@ -14,13 +14,9 @@ import ( "io" "net/http" "net/url" - "path" - "strconv" "strings" "time" - "golang.org/x/time/rate" - "github.com/elastic/elastic-agent-libs/logp" ) @@ -195,16 +191,23 @@ func (o Response) String() string { // https://${yourOktaDomain}/reports/rate-limit. // // See https://developer.okta.com/docs/reference/api/users/#list-users for details. -func GetUserDetails(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, omit Response, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) { - const endpoint = "/api/v1/users" +func GetUserDetails(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, omit Response, lim RateLimiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) { + var endpoint, path string + if user == "" { + endpoint = "/api/v1/users" + path = endpoint + } else { + endpoint = "/api/v1/users/{user}" + path = strings.Replace(endpoint, "{user}", user, 1) + } u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, user), + Path: path, RawQuery: query.Encode(), } - return getDetails[User](ctx, cli, u, key, user == "", omit, lim, window, log) + return getDetails[User](ctx, cli, u, endpoint, key, user == "", omit, lim, window, log) } // GetUserFactors returns Okta group roles using the groups API endpoint. host is the @@ -213,19 +216,20 @@ func GetUserDetails(ctx context.Context, cli *http.Client, host, key, user strin // See GetUserDetails for details of the query and rate limit parameters. 
// // See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/UserFactor/#tag/UserFactor/operation/listFactors. -func GetUserFactors(ctx context.Context, cli *http.Client, host, key, user string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Factor, http.Header, error) { - const endpoint = "/api/v1/users" - +func GetUserFactors(ctx context.Context, cli *http.Client, host, key, user string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Factor, http.Header, error) { if user == "" { return nil, nil, errors.New("no user specified") } + const endpoint = "/api/v1/users/{user}/factors" + path := strings.Replace(endpoint, "{user}", user, 1) + u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, user, "factors"), + Path: path, } - return getDetails[Factor](ctx, cli, u, key, true, OmitNone, lim, window, log) + return getDetails[Factor](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log) } // GetUserRoles returns Okta group roles using the groups API endpoint. host is the @@ -234,19 +238,20 @@ func GetUserFactors(ctx context.Context, cli *http.Client, host, key, user strin // See GetUserDetails for details of the query and rate limit parameters. // // See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/RoleAssignmentBGroup/#tag/RoleAssignmentBGroup/operation/listGroupAssignedRoles. 
-func GetUserRoles(ctx context.Context, cli *http.Client, host, key, user string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) { - const endpoint = "/api/v1/users" - +func GetUserRoles(ctx context.Context, cli *http.Client, host, key, user string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) { if user == "" { return nil, nil, errors.New("no user specified") } + const endpoint = "/api/v1/users/{user}/roles" + path := strings.Replace(endpoint, "{user}", user, 1) + u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, user, "roles"), + Path: path, } - return getDetails[Role](ctx, cli, u, key, true, OmitNone, lim, window, log) + return getDetails[Role](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log) } // GetUserGroupDetails returns Okta group details using the users API endpoint. host is the @@ -255,19 +260,20 @@ func GetUserRoles(ctx context.Context, cli *http.Client, host, key, user string, // See GetUserDetails for details of the query and rate limit parameters. // // See https://developer.okta.com/docs/reference/api/users/#request-parameters-8 (no anchor exists on the page for this endpoint) for details. 
-func GetUserGroupDetails(ctx context.Context, cli *http.Client, host, key, user string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Group, http.Header, error) { - const endpoint = "/api/v1/users" - +func GetUserGroupDetails(ctx context.Context, cli *http.Client, host, key, user string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Group, http.Header, error) { if user == "" { return nil, nil, errors.New("no user specified") } + const endpoint = "/api/v1/users/{user}/groups" + path := strings.Replace(endpoint, "{user}", user, 1) + u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, user, "groups"), + Path: path, } - return getDetails[Group](ctx, cli, u, key, true, OmitNone, lim, window, log) + return getDetails[Group](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log) } // GetGroupRoles returns Okta group roles using the groups API endpoint. host is the @@ -276,19 +282,20 @@ func GetUserGroupDetails(ctx context.Context, cli *http.Client, host, key, user // See GetUserDetails for details of the query and rate limit parameters. // // See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/RoleAssignmentBGroup/#tag/RoleAssignmentBGroup/operation/listGroupAssignedRoles. 
-func GetGroupRoles(ctx context.Context, cli *http.Client, host, key, group string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) { - const endpoint = "/api/v1/groups" - +func GetGroupRoles(ctx context.Context, cli *http.Client, host, key, group string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) { if group == "" { return nil, nil, errors.New("no group specified") } + const endpoint = "/api/v1/groups/{group}/roles" + path := strings.Replace(endpoint, "{group}", group, 1) + u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, group, "roles"), + Path: path, } - return getDetails[Role](ctx, cli, u, key, true, OmitNone, lim, window, log) + return getDetails[Role](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log) } // GetDeviceDetails returns Okta device details using the list devices API endpoint. host is the @@ -298,16 +305,24 @@ func GetGroupRoles(ctx context.Context, cli *http.Client, host, key, group strin // See GetUserDetails for details of the query and rate limit parameters. // // See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/Device/#tag/Device/operation/listDevices for details. 
-func GetDeviceDetails(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Device, http.Header, error) { - const endpoint = "/api/v1/devices" +func GetDeviceDetails(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Device, http.Header, error) { + var endpoint string + var path string + if device == "" { + endpoint = "/api/v1/devices" + path = endpoint + } else { + endpoint = "/api/v1/devices/{device}" + path = strings.Replace(endpoint, "{device}", device, 1) + } u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, device), + Path: path, RawQuery: query.Encode(), } - return getDetails[Device](ctx, cli, u, key, device == "", OmitNone, lim, window, log) + return getDetails[Device](ctx, cli, u, endpoint, key, device == "", OmitNone, lim, window, log) } // GetDeviceUsers returns Okta user details for users associated with the provided device identifier @@ -317,21 +332,22 @@ func GetDeviceDetails(ctx context.Context, cli *http.Client, host, key, device s // See GetUserDetails for details of the query and rate limit parameters. // // See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/Device/#tag/Device/operation/listDeviceUsers for details. -func GetDeviceUsers(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, omit Response, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) { +func GetDeviceUsers(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, omit Response, lim RateLimiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) { if device == "" { // No user associated with a null device. Not an error. 
return nil, nil, nil } - const endpoint = "/api/v1/devices" + const endpoint = "/api/v1/devices/{device}/users" + path := strings.Replace(endpoint, "{device}", device, 1) u := &url.URL{ Scheme: "https", Host: host, - Path: path.Join(endpoint, device, "users"), + Path: path, RawQuery: query.Encode(), } - du, h, err := getDetails[devUser](ctx, cli, u, key, true, omit, lim, window, log) + du, h, err := getDetails[devUser](ctx, cli, u, endpoint, key, true, omit, lim, window, log) if err != nil { return nil, h, err } @@ -356,7 +372,7 @@ type devUser struct { // for the specific user are returned, otherwise a list of all users is returned. // // See GetUserDetails for details of the query and rate limit parameters. -func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, key string, all bool, omit Response, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]E, http.Header, error) { +func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, endpoint string, key string, all bool, omit Response, lim RateLimiter, window time.Duration, log *logp.Logger) ([]E, http.Header, error) { url := u.String() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -370,8 +386,7 @@ func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, key req.Header.Set("Content-Type", contentType) req.Header.Set("Authorization", fmt.Sprintf("SSWS %s", key)) - log.Debugw("rate limit", "limit", lim.Limit(), "burst", lim.Burst(), "url", url) - err = lim.Wait(ctx) + err = lim.Wait(ctx, endpoint, u, log) if err != nil { return nil, nil, err } @@ -380,7 +395,7 @@ func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, key return nil, nil, err } defer resp.Body.Close() - err = oktaRateLimit(resp.Header, window, lim, log) + err = lim.Update(endpoint, resp.Header, window, log) if err != nil { io.Copy(io.Discard, resp.Body) return nil, nil, err @@ -443,59 +458,6 @@ func (e *Error) 
Error() string { return fmt.Sprintf("%s: %s", summary, strings.Join(causes, ",")) } -// oktaRateLimit implements the Okta rate limit policy translation. -// -// See https://developer.okta.com/docs/reference/rl-best-practices/ for details. -func oktaRateLimit(h http.Header, window time.Duration, limiter *rate.Limiter, log *logp.Logger) error { - limit := h.Get("X-Rate-Limit-Limit") - remaining := h.Get("X-Rate-Limit-Remaining") - reset := h.Get("X-Rate-Limit-Reset") - log.Debugw("rate limit header", "X-Rate-Limit-Limit", limit, "X-Rate-Limit-Remaining", remaining, "X-Rate-Limit-Reset", reset) - if limit == "" || remaining == "" || reset == "" { - return nil - } - - lim, err := strconv.ParseFloat(limit, 64) - if err != nil { - return err - } - rem, err := strconv.ParseFloat(remaining, 64) - if err != nil { - return err - } - rst, err := strconv.ParseInt(reset, 10, 64) - if err != nil { - return err - } - resetTime := time.Unix(rst, 0) - per := time.Until(resetTime).Seconds() - - // Be conservative here; the docs don't exactly specify burst rates. - // Make sure we can make at least one new request, even if we fail - // to get a non-zero rate.Limit. We could set to zero for the case - // that limit=rate.Inf, but that detail is not important. - burst := 1 - - rateLimit := rate.Limit(rem / per) - - // Process reset if we need to wait until reset to avoid a request against a zero quota. - if rateLimit <= 0 { - waitUntil := resetTime.UTC() - // next gives us a sane next window estimate, but the - // estimate will be overwritten when we make the next - // permissible API request. 
- next := rate.Limit(lim / window.Seconds()) - limiter.SetLimitAt(waitUntil, next) - limiter.SetBurstAt(waitUntil, burst) - log.Debugw("rate limit adjust", "reset_time", waitUntil, "next_rate", next, "next_burst", burst) - return nil - } - limiter.SetLimit(rateLimit) - limiter.SetBurst(burst) - log.Debugw("rate limit adjust", "set_rate", rateLimit, "set_burst", burst) - return nil -} - // Next returns the next URL query for a pagination sequence. If no further // page is available, Next returns io.EOF. func Next(h http.Header) (query url.Values, err error) { diff --git a/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta_test.go b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta_test.go index 9b04d3996bf..45b1b2a4ca4 100644 --- a/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta_test.go +++ b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/okta_test.go @@ -44,8 +44,8 @@ func Test(t *testing.T) { t.Skip("okta tests require ${OKTA_TOKEN} to be set") } - // Make a global limiter with the capacity to proceed once. - limiter := rate.NewLimiter(1, 1) + // Make a global limiter + limiter := NewRateLimiter() // There are a variety of windows, the most conservative is one minute. // The rate limit will be adjusted on the second call to the API if @@ -263,14 +263,14 @@ var localTests = []struct { name string msg string id string - fn func(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, lim *rate.Limiter, window time.Duration, log *logp.Logger) (any, http.Header, error) + fn func(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, lim RateLimiter, window time.Duration, log *logp.Logger) (any, http.Header, error) mkWant func(string) (any, error) }{ { // Test case constructed from API-returned value with details anonymised. 
name: "users", msg: `[{"id":"userid","status":"STATUS","created":"2023-05-14T13:37:20.000Z","activated":null,"statusChanged":"2023-05-15T01:50:30.000Z","lastLogin":"2023-05-15T01:59:20.000Z","lastUpdated":"2023-05-15T01:50:32.000Z","passwordChanged":"2023-05-15T01:50:32.000Z","recovery_question":{"question":"Who's a major player in the cowboy scene?","answer":"Annie Oakley"},"type":{"id":"typeid"},"profile":{"firstName":"name","lastName":"surname","mobilePhone":null,"secondEmail":null,"login":"name.surname@example.com","email":"name.surname@example.com"},"credentials":{"password":{"value":"secret"},"emails":[{"value":"name.surname@example.com","status":"VERIFIED","type":"PRIMARY"}],"provider":{"type":"OKTA","name":"OKTA"}},"_links":{"self":{"href":"https://localhost/api/v1/users/userid"}}}]`, - fn: func(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, lim *rate.Limiter, window time.Duration, log *logp.Logger) (any, http.Header, error) { + fn: func(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, lim RateLimiter, window time.Duration, log *logp.Logger) (any, http.Header, error) { return GetUserDetails(context.Background(), cli, host, key, user, query, OmitNone, lim, window, log) }, mkWant: mkWant[User], @@ -279,7 +279,7 @@ var localTests = []struct { // Test case from https://developer.okta.com/docs/api/openapi/okta-management/management/tag/Device/#tag/Device/operation/listDevices name: "devices", msg: `[{"id":"devid","status":"CREATED","created":"2019-10-02T18:03:07.000Z","lastUpdated":"2019-10-02T18:03:07.000Z","profile":{"displayName":"Example Device name 1","platform":"WINDOWS","serialNumber":"XXDDRFCFRGF3M8MD6D","sid":"S-1-11-111","registered":true,"secureHardwarePresent":false,"diskEncryptionType":"ALL_INTERNAL_VOLUMES"},"resourceType":"UDDevice","resourceDisplayName":{"value":"Example Device name 
1","sensitive":false},"resourceAlternateId":null,"resourceId":"guo4a5u7YAHhjXrMK0g4","_links":{"activate":{"href":"https://{yourOktaDomain}/api/v1/devices/guo4a5u7YAHhjXrMK0g4/lifecycle/activate","hints":{"allow":["POST"]}},"self":{"href":"https://{yourOktaDomain}/api/v1/devices/guo4a5u7YAHhjXrMK0g4","hints":{"allow":["GET","PATCH","PUT"]}},"users":{"href":"https://{yourOktaDomain}/api/v1/devices/guo4a5u7YAHhjXrMK0g4/users","hints":{"allow":["GET"]}}}},{"id":"guo4a5u7YAHhjXrMK0g5","status":"ACTIVE","created":"2023-06-21T23:24:02.000Z","lastUpdated":"2023-06-21T23:24:02.000Z","profile":{"displayName":"Example Device name 2","platform":"ANDROID","manufacturer":"Google","model":"Pixel 6","osVersion":"13:2023-05-05","registered":true,"secureHardwarePresent":true,"diskEncryptionType":"USER"},"resourceType":"UDDevice","resourceDisplayName":{"value":"Example Device name 2","sensitive":false},"resourceAlternateId":null,"resourceId":"guo4a5u7YAHhjXrMK0g5","_links":{"activate":{"href":"https://{yourOktaDomain}/api/v1/devices/guo4a5u7YAHhjXrMK0g5/lifecycle/activate","hints":{"allow":["POST"]}},"self":{"href":"https://{yourOktaDomain}/api/v1/devices/guo4a5u7YAHhjXrMK0g5","hints":{"allow":["GET","PATCH","PUT"]}},"users":{"href":"https://{yourOktaDomain}/api/v1/devices/guo4a5u7YAHhjXrMK0g5/users","hints":{"allow":["GET"]}}}}]`, - fn: func(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim *rate.Limiter, window time.Duration, log *logp.Logger) (any, http.Header, error) { + fn: func(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim RateLimiter, window time.Duration, log *logp.Logger) (any, http.Header, error) { return GetDeviceDetails(context.Background(), cli, host, key, device, query, lim, window, log) }, mkWant: mkWant[Device], @@ -289,7 +289,7 @@ var localTests = []struct { name: "devices_users", msg: 
`[{"created":"2023-08-07T21:48:27.000Z","managementStatus":"NOT_MANAGED","user":{"id":"userid","status":"STATUS","created":"2023-05-14T13:37:20.000Z","activated":null,"statusChanged":"2023-05-15T01:50:30.000Z","lastLogin":"2023-05-15T01:59:20.000Z","lastUpdated":"2023-05-15T01:50:32.000Z","passwordChanged":"2023-05-15T01:50:32.000Z","type":{"id":"typeid"},"profile":{"firstName":"name","lastName":"surname","mobilePhone":null,"secondEmail":null,"login":"name.surname@example.com","email":"name.surname@example.com"},"credentials":{"password":{"value":"secret"},"recovery_question":{"question":"Who's a major player in the cowboy scene?","answer":"Annie Oakley"},"emails":[{"value":"name.surname@example.com","status":"VERIFIED","type":"PRIMARY"}],"provider":{"type":"OKTA","name":"OKTA"}},"_links":{"self":{"href":"https://localhost/api/v1/users/userid"}}}}]`, id: "devid", - fn: func(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim *rate.Limiter, window time.Duration, log *logp.Logger) (any, http.Header, error) { + fn: func(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim RateLimiter, window time.Duration, log *logp.Logger) (any, http.Header, error) { return GetDeviceUsers(context.Background(), cli, host, key, device, query, OmitNone, lim, window, log) }, mkWant: mkWant[devUser], @@ -315,9 +315,7 @@ func TestLocal(t *testing.T) { for _, test := range localTests { t.Run(test.name, func(t *testing.T) { - // Make a global limiter with more capacity than will be set by the mock API. - // This will show the burst drop. - limiter := rate.NewLimiter(10, 10) + limiter := NewRateLimiter() // There are a variety of windows, the most conservative is one minute. 
// The rate limit will be adjusted on the second call to the API if @@ -377,12 +375,23 @@ func TestLocal(t *testing.T) { t.Errorf("unexpected result:\n- want\n+ got\n%s", cmp.Diff(want, got)) } - lim := limiter.Limit() - if lim < 49.0/60.0 || 50.0/60.0 < lim { - t.Errorf("unexpected rate limit (outside [49/60, 50/60]: %f", lim) + if len(limiter) != 1 { + t.Errorf("unexpected number endpoints track by rate limiter: %d", len(limiter)) } - if limiter.Burst() != 1 { // Set in GetUserDetails. - t.Errorf("unexpected burst: got:%d want:1", limiter.Burst()) + // retrieve the rate.Limiter parameters for the one endpoint + var limit rate.Limit + var burst int + for _, l := range limiter { + limit = l.Limit() + burst = l.Burst() + break + } + + if limit < 49.0/60.0 || 50.0/60.0 < limit { + t.Errorf("unexpected rate limit (outside [49/60, 50/60]: %f", limit) + } + if burst != 1 { + t.Errorf("unexpected burst: got:%d want:1", burst) } next, err := Next(h) diff --git a/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter.go b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter.go new file mode 100644 index 00000000000..1b58e01328c --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter.go @@ -0,0 +1,97 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. 
+ +package okta + +import ( + "context" + "net/http" + "net/url" + "strconv" + "time" + + "golang.org/x/time/rate" + + "github.com/elastic/elastic-agent-libs/logp" +) + +type RateLimiter map[string]*rate.Limiter + +func NewRateLimiter() RateLimiter { + r := make(RateLimiter) + return r +} + +func (r RateLimiter) limiter(path string) *rate.Limiter { + if existing, ok := r[path]; ok { + return existing + } + initial := rate.NewLimiter(1, 1) // Allow a single fetch operation to obtain limits from the API + r[path] = initial + return initial +} + +func (r RateLimiter) Wait(ctx context.Context, endpoint string, url *url.URL, log *logp.Logger) (err error) { + limiter := r.limiter(endpoint) + log.Debugw("rate limit", "limit", limiter.Limit(), "burst", limiter.Burst(), "url", url.String()) + return limiter.Wait(ctx) +} + +// Update implements the Okta rate limit policy translation. +// +// See https://developer.okta.com/docs/reference/rl-best-practices/ for details. +func (r RateLimiter) Update(endpoint string, h http.Header, window time.Duration, log *logp.Logger) error { + limiter := r.limiter(endpoint) + limit := h.Get("X-Rate-Limit-Limit") + remaining := h.Get("X-Rate-Limit-Remaining") + reset := h.Get("X-Rate-Limit-Reset") + log.Debugw("rate limit header", "X-Rate-Limit-Limit", limit, "X-Rate-Limit-Remaining", remaining, "X-Rate-Limit-Reset", reset) + if limit == "" || remaining == "" || reset == "" { + return nil + } + + lim, err := strconv.ParseFloat(limit, 64) + if err != nil { + return err + } + rem, err := strconv.ParseFloat(remaining, 64) + if err != nil { + return err + } + rst, err := strconv.ParseInt(reset, 10, 64) + if err != nil { + return err + } + resetTime := time.Unix(rst, 0) + per := time.Until(resetTime).Seconds() + + // Be conservative here; the docs don't exactly specify burst rates. + // Make sure we can make at least one new request, even if we fail + // to get a non-zero rate.Limit. 
We could set to zero for the case + // that limit=rate.Inf, but that detail is not important. + burst := 1 + + rateLimit := rate.Limit(rem / per) + + // Process reset if we need to wait until reset to avoid a request against a zero quota. + if rateLimit <= 0 { + // Reset limiter to block requests until reset + limiter := rate.NewLimiter(0, 0) + r[endpoint] = limiter + + // next gives us a sane next window estimate, but the + // estimate will be overwritten when we make the next + // permissible API request. + next := rate.Limit(lim / window.Seconds()) + waitUntil := resetTime.UTC() + limiter.SetLimitAt(waitUntil, next) + limiter.SetBurstAt(waitUntil, burst) + log.Debugw("rate limit reset", "reset_time", waitUntil, "next_rate", next, "next_burst", burst) + return nil + } + limiter.SetLimit(rateLimit) + limiter.SetBurst(burst) + log.Debugw("rate limit adjust", "set_rate", rateLimit, "set_burst", burst) + return nil +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter_test.go b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter_test.go new file mode 100644 index 00000000000..1492e55c8a6 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta/ratelimiter_test.go @@ -0,0 +1,84 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. 
+ +package okta + +import ( + "net/http" + "strconv" + "testing" + "time" + + "github.com/elastic/elastic-agent-libs/logp" +) + +func TestRateLimiter(t *testing.T) { + logp.TestingSetup() + + t.Run("separation by endpoint", func(t *testing.T) { + r := NewRateLimiter() + limiter1 := r.limiter("/foo") + limiter2 := r.limiter("/bar") + + limiter1.SetBurst(1000) + + if limiter2.Burst() == 1000 { + t.Errorf("changes to one endpoint's limits affected another") + } + }) + + t.Run("Update stops requests when none are remaining", func(t *testing.T) { + r := NewRateLimiter() + + const endpoint = "/foo" + limiter := r.limiter(endpoint) + + if !limiter.Allow() { + t.Errorf("doesn't allow an initial request") + } + + now := time.Now().Unix() + reset := now + 30 + + headers := http.Header{ + "X-Rate-Limit-Limit": []string{"60"}, + "X-Rate-Limit-Remaining": []string{"0"}, + "X-Rate-Limit-Reset": []string{strconv.FormatInt(reset, 10)}, + } + window := time.Minute + + err := r.Update(endpoint, headers, window, logp.L()) + if err != nil { + t.Errorf("unexpected error from Update(): %v", err) + } + limiter = r.limiter(endpoint) + + if limiter.Allow() { + t.Errorf("allowed a request when none are remaining") + } + + if limiter.AllowN(time.Unix(reset-1, 999999999), 1) { + t.Errorf("allowed a request before reset, when none are remaining") + } + + if !limiter.AllowN(time.Unix(reset+1, 0), 1) { + t.Errorf("doesn't allow requests to resume after reset") + } + + if limiter.Limit() != 1.0 { + t.Errorf("unexpected rate following reset (not 60 requests / 60 seconds): %f", limiter.Limit()) + } + + if limiter.Burst() != 1 { + t.Errorf("unexpected burst following reset (not 1): %d", limiter.Burst()) + } + + limiter.SetBurstAt(time.Unix(reset, 0), 100) // increase bucket size to check token accumulation + tokens := limiter.TokensAt(time.Unix(reset+30, 0)) + if tokens < 29.5 || tokens > 30.0 { + t.Errorf("tokens don't accumulate at the expected rate. 
tokens 30s after reset: %f", tokens) + } + + }) +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/okta/okta.go b/x-pack/filebeat/input/entityanalytics/provider/okta/okta.go index 5d68cf3f5c4..30103d3ccdb 100644 --- a/x-pack/filebeat/input/entityanalytics/provider/okta/okta.go +++ b/x-pack/filebeat/input/entityanalytics/provider/okta/okta.go @@ -23,7 +23,6 @@ import ( "go.elastic.co/ecszap" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "golang.org/x/time/rate" v2 "github.com/elastic/beats/v7/filebeat/input/v2" "github.com/elastic/beats/v7/libbeat/beat" @@ -60,7 +59,7 @@ type oktaInput struct { cfg conf client *http.Client - lim *rate.Limiter + lim okta.RateLimiter metrics *inputMetrics logger *logp.Logger @@ -111,7 +110,7 @@ func (p *oktaInput) Run(inputCtx v2.Context, store *kvstore.Store, client beat.C updateTimer := time.NewTimer(updateWaitTime) // Allow a single fetch operation to obtain limits from the API. - p.lim = rate.NewLimiter(1, 1) + p.lim = okta.NewRateLimiter() if p.cfg.Tracer != nil { id := sanitizeFileName(inputCtx.IDWithoutName) diff --git a/x-pack/filebeat/input/entityanalytics/provider/okta/okta_test.go b/x-pack/filebeat/input/entityanalytics/provider/okta/okta_test.go index 5752370c4ce..e7e2bffbba2 100644 --- a/x-pack/filebeat/input/entityanalytics/provider/okta/okta_test.go +++ b/x-pack/filebeat/input/entityanalytics/provider/okta/okta_test.go @@ -18,7 +18,6 @@ import ( "testing" "time" - "golang.org/x/time/rate" "gopkg.in/natefinch/lumberjack.v2" "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/provider/okta/internal/okta" @@ -177,6 +176,7 @@ func TestOktaDoFetch(t *testing.T) { if err != nil { t.Errorf("failed to parse server URL: %v", err) } + rateLimiter := okta.NewRateLimiter() a := oktaInput{ cfg: conf{ OktaDomain: u.Host, @@ -185,7 +185,7 @@ func TestOktaDoFetch(t *testing.T) { EnrichWith: test.enrichWith, }, client: ts.Client(), - lim: rate.NewLimiter(1, 1), + lim: rateLimiter, logger: logp.L(), } if *trace 
{