Skip to content

Commit

Permalink
GitHub source logger clean up (#3269)
Browse files Browse the repository at this point in the history
* GitHub source logger clean up

* applied pr comments

* applied pr comments

* applied pr comments

* applied PR review comments
  • Loading branch information
LaraCroftDev authored Sep 9, 2024
1 parent 8a4d62c commit 17f6c98
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 54 deletions.
6 changes: 3 additions & 3 deletions pkg/sources/github/connector_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ type tokenConnector struct {
apiClient *github.Client
token string
isGitHubEnterprise bool
handleRateLimit func(error) bool
handleRateLimit func(context.Context, error) bool
user string
userMu sync.Mutex
}

var _ connector = (*tokenConnector)(nil)

func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(error) bool) (*tokenConnector, error) {
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) {
const httpTimeoutSeconds = 60
httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds))
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
Expand Down Expand Up @@ -68,7 +68,7 @@ func (c *tokenConnector) getUser(ctx context.Context) (string, error) {
)
for {
user, _, err = c.apiClient.Users.Get(ctx, "")
if c.handleRateLimit(err) {
if c.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down
84 changes: 41 additions & 43 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ type Source struct {
scanOptMu sync.Mutex // protects the scanOptions
scanOptions *git.ScanOptions

log logr.Logger
conn *sourcespb.GitHub
jobPool *errgroup.Group
resumeInfoMutex sync.Mutex
Expand Down Expand Up @@ -117,21 +116,21 @@ type filteredRepoCache struct {
include, exclude []glob.Glob
}

func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
func (s *Source) newFilteredRepoCache(ctx context.Context, c cache.Cache[string], include, exclude []string) *filteredRepoCache {
includeGlobs := make([]glob.Glob, 0, len(include))
excludeGlobs := make([]glob.Glob, 0, len(exclude))
for _, ig := range include {
g, err := glob.Compile(ig)
if err != nil {
s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
ctx.Logger().V(1).Info("invalid include glob", "include_value", ig, "err", err)
continue
}
includeGlobs = append(includeGlobs, g)
}
for _, eg := range exclude {
g, err := glob.Compile(eg)
if err != nil {
s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
ctx.Logger().V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
continue
}
excludeGlobs = append(excludeGlobs, g)
Expand Down Expand Up @@ -180,8 +179,6 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
return err
}

s.log = aCtx.Logger()

s.name = name
s.sourceID = sourceID
s.jobID = jobID
Expand All @@ -208,7 +205,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
}
s.memberCache = make(map[string]struct{})

s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](),
s.filteredRepoCache = s.newFilteredRepoCache(aCtx,
memory.New[string](),
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
s.conn.GetIgnoreRepos(),
)
Expand Down Expand Up @@ -360,7 +358,7 @@ RepoLoop:
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
repo = gist.GetGitPullURL()
if s.handleRateLimit(err) {
if s.handleRateLimit(repoCtx, err) {
continue
}
if err != nil {
Expand All @@ -374,7 +372,7 @@ RepoLoop:
// Cache repository info.
for {
ghRepo, _, err := s.connector.APIClient().Repositories.Get(repoCtx, urlParts[1], urlParts[2])
if s.handleRateLimit(err) {
if s.handleRateLimit(repoCtx, err) {
continue
}
if err != nil {
Expand All @@ -389,8 +387,7 @@ RepoLoop:
s.repos = append(s.repos, repo)
}
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))

ctx.Logger().Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))
// We must sort the repos so we can resume later if necessary.
sort.Strings(s.repos)
return nil
Expand All @@ -417,7 +414,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context) error {

func (s *Source) enumerateUnauthenticated(ctx context.Context) {
if s.orgsCache.Count() > unauthGithubOrgRateLimt {
s.log.Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
ctx.Logger().Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
}

for _, org := range s.orgsCache.Keys() {
Expand All @@ -441,7 +438,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
var err error
for {
ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "")
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand All @@ -454,10 +451,10 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
if !specificScope {
// Enumerate the user's orgs and repos if none were specified.
if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil {
s.log.Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
ctx.Logger().Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
}
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil {
s.log.Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
ctx.Logger().Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
}

if isGithubEnterprise {
Expand Down Expand Up @@ -486,7 +483,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
}

if s.conn.ScanUsers && len(s.memberCache) > 0 {
s.log.Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
ctx.Logger().Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
s.addReposForMembers(ctx)
}
}
Expand All @@ -507,9 +504,9 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
if err != nil {
return err
}
s.log.Info("Scanning repos", "org_members", len(s.memberCache))
ctx.Logger().Info("Scanning repos", "org_members", len(s.memberCache))
for member := range s.memberCache {
logger := s.log.WithValues("member", member)
logger := ctx.Logger().WithValues("member", member)
if err := s.addUserGistsToCache(ctx, member); err != nil {
logger.Error(err, "error fetching gists by user")
}
Expand All @@ -536,7 +533,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
var scannedCount uint64 = 1

s.log.V(2).Info("Found repos to scan", "count", len(s.repos))
ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))

// If there is resume information available, limit this scan to only the repos that still need scanning.
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
Expand Down Expand Up @@ -574,7 +571,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
s.log.Error(err, "failed to scan repository")
ctx.Logger().Error(err, "failed to scan repository")
return nil
}
repoCtx := context.WithValues(ctx, "repo", repoURL)
Expand Down Expand Up @@ -618,7 +615,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error

_ = s.jobPool.Wait()
if scanErrs.Count() > 0 {
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
ctx.Logger().Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
}
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")

Expand Down Expand Up @@ -666,7 +663,7 @@ var (
// Authenticated users have a rate limit of 5,000 requests per hour,
// however, certain actions are subject to a stricter "secondary" limit.
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
func (s *Source) handleRateLimit(errIn error) bool {
func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
if errIn == nil {
return false
}
Expand Down Expand Up @@ -705,12 +702,12 @@ func (s *Source) handleRateLimit(errIn error) bool {
if retryAfter > 0 {
retryAfter = retryAfter + jitter
rateLimitResumeTime = now.Add(retryAfter)
s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
} else {
retryAfter = (5 * time.Minute) + jitter
rateLimitResumeTime = now.Add(retryAfter)
// TODO: Use exponential backoff instead of static retry time.
s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
ctx.Logger().Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
}

rateLimitMu.Unlock()
Expand All @@ -725,13 +722,13 @@ func (s *Source) handleRateLimit(errIn error) bool {
}

func (s *Source) addReposForMembers(ctx context.Context) {
s.log.Info("Fetching repos from members", "members", len(s.memberCache))
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
for member := range s.memberCache {
if err := s.addUserGistsToCache(ctx, member); err != nil {
s.log.Info("Unable to fetch gists by user", "user", member, "error", err)
ctx.Logger().Info("Unable to fetch gists by user", "user", member, "error", err)
}
if err := s.getReposByUser(ctx, member); err != nil {
s.log.Info("Unable to fetch repos by user", "user", member, "error", err)
ctx.Logger().Info("Unable to fetch repos by user", "user", member, "error", err)
}
}
}
Expand All @@ -740,10 +737,11 @@ func (s *Source) addReposForMembers(ctx context.Context) {
// and adds them to the filteredRepoCache.
func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
gistOpts := &github.GistListOptions{}
logger := s.log.WithValues("user", user)
logger := ctx.Logger().WithValues("user", user)

for {
gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -788,7 +786,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
}

func (s *Source) addAllVisibleOrgs(ctx context.Context) {
s.log.V(2).Info("enumerating all visible organizations on GHE")
ctx.Logger().V(2).Info("enumerating all visible organizations on GHE")
// Enumeration on this endpoint does not use pages; it uses a since ID.
// The endpoint will return organizations with an ID greater than the given since ID.
// Empty org response is our cue to break the enumeration loop.
Expand All @@ -800,11 +798,11 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
}
for {
orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
s.log.Error(err, "could not list all organizations")
ctx.Logger().Error(err, "could not list all organizations")
return
}

Expand All @@ -813,7 +811,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
}

lastOrgID := *orgs[len(orgs)-1].ID
s.log.V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
ctx.Logger().V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
orgOpts.Since = lastOrgID

for _, org := range orgs {
Expand All @@ -827,7 +825,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
continue
}
s.orgsCache.Set(name, name)
s.log.V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name)
ctx.Logger().V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name)
}
}
}
Expand All @@ -836,10 +834,10 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
orgOpts := &github.ListOptions{
PerPage: defaultPagination,
}
logger := s.log.WithValues("user", user)
logger := ctx.Logger().WithValues("user", user)
for {
orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -869,10 +867,10 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
},
}

logger := s.log.WithValues("org", org)
logger := ctx.Logger().WithValues("org", org)
for {
members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil || len(members) == 0 {
Expand Down Expand Up @@ -994,7 +992,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
}
for {
comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1107,7 +1105,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha

for {
issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}

Expand Down Expand Up @@ -1179,7 +1177,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch

for {
issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1244,7 +1242,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c

for {
prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1276,7 +1274,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk

for {
prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
Expand Down
8 changes: 3 additions & 5 deletions pkg/sources/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"testing"
"time"

"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"github.com/google/go-github/v63/github"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -369,7 +368,8 @@ func TestNormalizeRepos(t *testing.T) {

func TestHandleRateLimit(t *testing.T) {
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
assert.False(t, s.handleRateLimit(nil))
ctx := context.Background()
assert.False(t, s.handleRateLimit(ctx, nil))

// Request
reqUrl, _ := url.Parse("https://github.com/trufflesecurity/trufflehog")
Expand Down Expand Up @@ -400,7 +400,7 @@ func TestHandleRateLimit(t *testing.T) {
Message: "Too Many Requests",
}

assert.True(t, s.handleRateLimit(err))
assert.True(t, s.handleRateLimit(ctx, err))
}

func TestEnumerateUnauthenticated(t *testing.T) {
Expand Down Expand Up @@ -721,7 +721,6 @@ func Test_setProgressCompleteWithRepo_resumeInfo(t *testing.T) {

s := &Source{
repos: []string{},
log: logr.Discard(),
}

for _, tt := range tests {
Expand Down Expand Up @@ -772,7 +771,6 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
for _, tt := range tests {
s := &Source{
repos: tt.repos,
log: logr.Discard(),
}

s.setProgressCompleteWithRepo(tt.index, tt.offset, "")
Expand Down
Loading

0 comments on commit 17f6c98

Please sign in to comment.