Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of GitHub ratelimit information #2041

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ require (
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/go-cmp v0.6.0
github.com/google/go-containerregistry v0.17.0
github.com/google/go-github/v42 v42.0.0
github.com/google/go-github/v57 v57.0.0
github.com/google/uuid v1.5.0
github.com/googleapis/gax-go/v2 v2.12.0
github.com/h2non/filetype v1.1.3
Expand Down Expand Up @@ -174,7 +174,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v23.1.21+incompatible // indirect
github.com/google/go-github/v57 v57.0.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.7 // indirect
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-containerregistry v0.17.0 h1:5p+zYs/R4VGHkhyvgWurWrpJ2hW4Vv9fQI+GzdcwXLk=
github.com/google/go-containerregistry v0.17.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ=
github.com/google/go-github/v42 v42.0.0 h1:YNT0FwjPrEysRkLIiKuEfSvBPCGKphW5aS5PxwaoLec=
github.com/google/go-github/v42 v42.0.0/go.mod h1:jgg/jvyI0YlDOM1/ps6XYh04HNQ3vKf0CVko62/EhRg=
github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs=
github.com/google/go-github/v57 v57.0.0/go.mod h1:s0omdnye0hvK/ecLvpsGfJMiRt85PimQh4oygmLIxHw=
github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
Expand Down
5 changes: 3 additions & 2 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ import (
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object"
"github.com/google/go-github/v42/github"
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
"github.com/google/go-github/v57/github"
"golang.org/x/oauth2"
"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"

diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"

"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
Expand Down
157 changes: 81 additions & 76 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (
"sync/atomic"
"time"

"golang.org/x/exp/rand"

"github.com/bradleyfalzon/ghinstallation/v2"
"github.com/go-logr/logr"
"github.com/gobwas/glob"
"github.com/google/go-github/v42/github"
"github.com/google/go-github/v57/github"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -388,7 +390,6 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
return
}

var resp *github.Response
urlPathParts := strings.Split(u.Path, "/")
switch len(urlPathParts) {
case 2:
Expand All @@ -397,8 +398,8 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
repoName := urlPathParts[1]
repoName = strings.TrimSuffix(repoName, ".git")
for {
gist, resp, err = s.apiClient.Gists.Get(ctx, repoName)
if !s.handleRateLimit(err, resp) {
gist, _, err = s.apiClient.Gists.Get(ctx, repoName)
if !s.handleRateLimit(err) {
break
}
}
Expand All @@ -415,8 +416,8 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
repoName := urlPathParts[2]
repoName = strings.TrimSuffix(repoName, ".git")
for {
repo, resp, err = s.apiClient.Repositories.Get(ctx, owner, repoName)
if !s.handleRateLimit(err, resp) {
repo, _, err = s.apiClient.Repositories.Get(ctx, owner, repoName)
if !s.handleRateLimit(err) {
break
}
}
Expand Down Expand Up @@ -584,13 +585,12 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri

var (
ghUser *github.User
resp *github.Response
)

ctx.Logger().V(1).Info("Enumerating with token", "endpoint", apiEndpoint)
for {
ghUser, resp, err = s.apiClient.Users.Get(ctx, "")
if handled := s.handleRateLimit(err, resp); handled {
ghUser, _, err = s.apiClient.Users.Get(ctx, "")
if s.handleRateLimit(err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -696,7 +696,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
// Does this need to be separate from |s.httpClient|?
instHTTPClient := common.RetryableHttpClientTimeout(60)
instHTTPClient.Transport = appItr
installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, instHTTPClient)
installationClient, err = github.NewClient(instHTTPClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
if err != nil {
return nil, err
}
Expand All @@ -713,7 +713,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
itr.BaseURL = apiEndpoint

s.httpClient.Transport = itr
s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient)
s.apiClient, err = github.NewClient(s.httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -753,7 +753,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
return github.NewClient(httpClient), nil
}

return github.NewEnterpriseClient(apiEndpoint, apiEndpoint, httpClient)
return github.NewClient(httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
}

func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
Expand Down Expand Up @@ -869,53 +869,70 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, re
return duration, nil
}

// handleRateLimit returns true if a rate limit was handled
// Unauthenticated access to most github endpoints has a rate limit of 60 requests per hour.
// This will likely only be exhausted if many users/orgs are scanned without auth
func (s *Source) handleRateLimit(errIn error, res *github.Response) bool {
var (
knownWait = true
remaining = 0
retryAfter time.Duration
)
var (
rateLimitMu sync.RWMutex
rateLimitResumeTime time.Time
)

// GitHub has both primary (RateLimit) and secondary (AbuseRateLimit) errors.
var rateLimit *github.RateLimitError
var abuseLimit *github.AbuseRateLimitError
if errors.As(errIn, &rateLimit) {
// Do nothing
} else if errors.As(errIn, &abuseLimit) {
retryAfter = abuseLimit.GetRetryAfter()
} else {
// handleRateLimit returns true if a rate limit was handled
//
// Unauthenticated users have a rate limit of 60 requests per hour.
// Authenticated users have a rate limit of 5,000 requests per hour,
// however, certain actions are subject to a stricter "secondary" limit.
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
func (s *Source) handleRateLimit(errIn error) bool {
rgmz marked this conversation as resolved.
Show resolved Hide resolved
if errIn == nil {
return false
}

githubNumRateLimitEncountered.WithLabelValues(s.name).Inc()
// Parse retry information from response headers, unless a Retry-After value was already provided.
// https://docs.github.com/en/rest/overview/resources-in-the-rest-api#exceeding-the-rate-limit
if retryAfter <= 0 && res != nil {
var err error
remaining, err = strconv.Atoi(res.Header.Get("x-ratelimit-remaining"))
if err != nil {
knownWait = false
rateLimitMu.RLock()
resumeTime := rateLimitResumeTime
rateLimitMu.RUnlock()

var retryAfter time.Duration
if resumeTime.IsZero() || time.Now().After(resumeTime) {
rateLimitMu.Lock()
Copy link
Collaborator

@rosecodym rosecodym Feb 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible for a bunch of workers to get blocked here, right? And if they entered this function because of the secondary rate limit, won't they therefore all end up waiting for their rate limit period in sequence?

I'm having trouble juggling these locks in my head so I might be completely misunderstanding this. (Or maybe this is a theoretical possibility that's unlikely in practice?)

Copy link
Contributor Author

@rgmz rgmz Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While it's possible that multiple workers will hit this, the actual work performed in this block is minimal and the lock should be released fairly quickly. time.Sleep is called after rateLimitMu.Unlock().

rateLimitMu.Unlock()
} else {
retryAfter = time.Until(resumeTime)
}
githubNumRateLimitEncountered.WithLabelValues(s.name).Inc()
time.Sleep(retryAfter)

Edit: I could always add another check after the lock to see if the resume time is still zero or in the past.


var (
now = time.Now()

// GitHub has both primary (RateLimit) and secondary (AbuseRateLimit) errors.
limitType string
rateLimit *github.RateLimitError
abuseLimit *github.AbuseRateLimitError
)
if errors.As(errIn, &rateLimit) {
limitType = "primary"
rate := rateLimit.Rate
if rate.Remaining == 0 { // TODO: Will we ever receive a |RateLimitError| when remaining > 0?
retryAfter = rate.Reset.Sub(now)
}
} else if errors.As(errIn, &abuseLimit) {
limitType = "secondary"
retryAfter = abuseLimit.GetRetryAfter()
} else {
rateLimitMu.Unlock()
return false
}

resetTime, err := strconv.Atoi(res.Header.Get("x-ratelimit-reset"))
if err != nil || resetTime == 0 {
knownWait = false
} else if resetTime > 0 {
retryAfter = time.Duration(int64(resetTime)-time.Now().Unix()) * time.Second
jitter := time.Duration(rand.Intn(10)+1) * time.Second
rgmz marked this conversation as resolved.
Show resolved Hide resolved
if retryAfter > 0 {
retryAfter = retryAfter + jitter
rateLimitResumeTime = now.Add(retryAfter)
s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
} else {
retryAfter = (5 * time.Minute) + jitter
rateLimitResumeTime = now.Add(retryAfter)
// TODO: Use exponential backoff instead of static retry time.
s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
}
}

resumeTime := time.Now().Add(retryAfter).String()
if knownWait && remaining == 0 && retryAfter > 0 {
s.log.V(2).Info("rate limited", "retry_after", retryAfter.String(), "resume_time", resumeTime)
rateLimitMu.Unlock()
} else {
// TODO: Use exponential backoff instead of static retry time.
retryAfter = time.Minute * 5
s.log.V(2).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", resumeTime)
retryAfter = time.Until(resumeTime)
}

githubNumRateLimitEncountered.WithLabelValues(s.name).Inc()
time.Sleep(retryAfter)
githubSecondsSpentRateLimited.WithLabelValues(s.name).Add(retryAfter.Seconds())
return true
Expand All @@ -940,10 +957,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
logger := s.log.WithValues("user", user)
for {
gists, res, err := s.apiClient.Gists.List(ctx, user, gistOpts)
if err == nil {
res.Body.Close()
}
if handled := s.handleRateLimit(err, res); handled {
if s.handleRateLimit(err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -996,11 +1010,8 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
},
}
for {
orgs, resp, err := s.apiClient.Organizations.ListAll(ctx, orgOpts)
if err == nil {
resp.Body.Close()
}
rgmz marked this conversation as resolved.
Show resolved Hide resolved
if handled := s.handleRateLimit(err, resp); handled {
orgs, _, err := s.apiClient.Organizations.ListAll(ctx, orgOpts)
if s.handleRateLimit(err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1037,10 +1048,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
logger := s.log.WithValues("user", user)
for {
orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts)
if err == nil {
resp.Body.Close()
}
if handled := s.handleRateLimit(err, resp); handled {
if handled := s.handleRateLimit(err); handled {
continue
}
if err != nil {
Expand Down Expand Up @@ -1075,10 +1083,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
logger := s.log.WithValues("org", org)
for {
members, res, err := s.apiClient.Organizations.ListMembers(ctx, org, opts)
if err == nil {
defer res.Body.Close()
}
if handled := s.handleRateLimit(err, res); handled {
if s.handleRateLimit(err) {
continue
}
if err != nil || len(members) == 0 {
Expand Down Expand Up @@ -1150,8 +1155,8 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm
Page: initialPage,
}
for {
comments, resp, err := s.apiClient.Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err, resp) {
comments, _, err := s.apiClient.Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err) {
break
}
if err != nil {
Expand Down Expand Up @@ -1254,8 +1259,8 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch
}

for {
issues, resp, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts)
if s.handleRateLimit(err, resp) {
issues, _, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts)
if s.handleRateLimit(err) {
break
}

Expand Down Expand Up @@ -1287,8 +1292,8 @@ func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunks
}

for {
issueComments, resp, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts)
if s.handleRateLimit(err, resp) {
issueComments, _, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts)
if s.handleRateLimit(err) {
break
}

Expand Down Expand Up @@ -1321,8 +1326,8 @@ func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan
}

for {
prs, resp, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts)
if s.handleRateLimit(err, resp) {
prs, _, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts)
if s.handleRateLimit(err) {
break
}

Expand Down Expand Up @@ -1354,8 +1359,8 @@ func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksCha
}

for {
prComments, resp, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts)
if s.handleRateLimit(err, resp) {
prComments, _, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts)
if s.handleRateLimit(err) {
break
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/sources/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"github.com/google/go-github/v42/github"
"github.com/google/go-github/v57/github"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/anypb"
Expand Down Expand Up @@ -330,13 +330,13 @@ func TestNormalizeRepos(t *testing.T) {

func TestHandleRateLimit(t *testing.T) {
s := initTestSource(nil)
assert.False(t, s.handleRateLimit(nil, nil))
assert.False(t, s.handleRateLimit(nil))

err := &github.RateLimitError{}
res := &github.Response{Response: &http.Response{Header: make(http.Header)}}
res.Header.Set("x-ratelimit-remaining", "0")
res.Header.Set("x-ratelimit-reset", strconv.FormatInt(time.Now().Unix()+1, 10))
assert.True(t, s.handleRateLimit(err, res))
assert.True(t, s.handleRateLimit(err))
}

func TestEnumerateUnauthenticated(t *testing.T) {
Expand Down
Loading
Loading