From 6923843bc4046d849dc561b134546d198e9b7f15 Mon Sep 17 00:00:00 2001 From: ahrav Date: Mon, 4 Nov 2024 08:54:30 -0800 Subject: [PATCH] [chore] - minor cleanup S3 source (#3554) * cleanup s3 source * revert --- pkg/sources/s3/s3.go | 132 +++++++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 55 deletions(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index c32faa0730f1..27c9e9b4e5e7 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -15,8 +15,6 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/sts" "github.com/go-errors/errors" - "github.com/go-logr/logr" - "github.com/trufflesecurity/trufflehog/v3/pkg/log" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -24,6 +22,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/handlers" + "github.com/trufflesecurity/trufflehog/v3/pkg/log" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer" @@ -40,16 +39,17 @@ const ( type Source struct { name string - sourceId sources.SourceID - jobId sources.JobID + sourceID sources.SourceID + jobID sources.JobID verify bool concurrency int - log logr.Logger sources.Progress + conn *sourcespb.S3 + errorCount *sync.Map - conn *sourcespb.S3 jobPool *errgroup.Group maxObjectSize int64 + sources.CommonSourceUnitUnmarshaller } @@ -59,43 +59,41 @@ var _ sources.SourceUnitUnmarshaller = (*Source)(nil) var _ sources.Validator = (*Source)(nil) // Type returns the type of source -func (s *Source) Type() sourcespb.SourceType { - return SourceType -} +func (s *Source) Type() sourcespb.SourceType { return SourceType } -func (s *Source) SourceID() sources.SourceID { - return s.sourceId -} +func (s *Source) SourceID() sources.SourceID { return s.sourceID } -func (s *Source) JobID() sources.JobID { - return s.jobId -} +func (s *Source) JobID() sources.JobID { return s.jobID } // Init returns an initialized AWS source -func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - s.log = context.WithValues(aCtx, "source", s.Type(), "name", name).Logger() - +func (s *Source) Init( + _ context.Context, + name string, + jobID sources.JobID, + sourceID sources.SourceID, + verify bool, + connection *anypb.Any, + concurrency int, +) error { s.name = name - s.sourceId = sourceId - s.jobId = jobId + s.sourceID = sourceID + s.jobID = jobID s.verify = verify s.concurrency = concurrency s.errorCount = &sync.Map{} - s.log = aCtx.Logger() s.jobPool = &errgroup.Group{} s.jobPool.SetLimit(concurrency) var conn sourcespb.S3 - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) - if err != nil { - return errors.WrapPrefix(err, "error unmarshalling connection", 0) + if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil { + return fmt.Errorf("error unmarshalling connection: %w", err) } s.conn = &conn s.setMaxObjectSize(conn.GetMaxObjectSize()) - if len(conn.Buckets) > 0 && len(conn.IgnoreBuckets) > 0 { - return fmt.Errorf("either a bucket include list or a bucket ignore list can be specified, but not both") + if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 { + return errors.New("either a bucket include list or a bucket ignore list can be specified, but not both") } return nil @@ -110,8 +108,7 @@ func (s *Source) Validate(ctx context.Context) []error { } } - err := s.visitRoles(ctx, visitor) - if err != nil { + if err := s.visitRoles(ctx, visitor); err != nil { errs = append(errs, err) } @@ -136,11 +133,15 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) { switch cred := s.conn.GetCredential().(type) { case *sourcespb.S3_SessionToken: - cfg.Credentials = credentials.NewStaticCredentials(cred.SessionToken.Key, cred.SessionToken.Secret, cred.SessionToken.SessionToken) + cfg.Credentials = credentials.NewStaticCredentials( + cred.SessionToken.GetKey(), + cred.SessionToken.GetSecret(), + cred.SessionToken.GetSessionToken(), + ) log.RedactGlobally(cred.SessionToken.GetSecret()) log.RedactGlobally(cred.SessionToken.GetSessionToken()) case *sourcespb.S3_AccessKey: - cfg.Credentials = credentials.NewStaticCredentials(cred.AccessKey.Key, cred.AccessKey.Secret, "") + cfg.Credentials = credentials.NewStaticCredentials(cred.AccessKey.GetKey(), cred.AccessKey.GetSecret(), "") log.RedactGlobally(cred.AccessKey.GetSecret()) case *sourcespb.S3_Unauthenticated: cfg.Credentials = credentials.AnonymousCredentials @@ -174,12 +175,12 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) { // IAM identity needs s3:ListBuckets permission func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) { - if len(s.conn.Buckets) > 0 { - return s.conn.Buckets, nil + if buckets := s.conn.GetBuckets(); len(buckets) > 0 { + return buckets, nil } - ignore := make(map[string]struct{}, len(s.conn.IgnoreBuckets)) - for _, bucket := range s.conn.IgnoreBuckets { + ignore := make(map[string]struct{}, len(s.conn.GetIgnoreBuckets())) + for _, bucket := range s.conn.GetIgnoreBuckets() { ignore[bucket] = struct{}{} } @@ -198,27 +199,32 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) { return bucketsToScan, nil } -func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) { - objectCount := uint64(0) +func (s *Source) scanBuckets( + ctx context.Context, + client *s3.S3, + role string, + bucketsToScan []string, + chunksChan chan *sources.Chunk, +) { + var objectCount uint64 - logger := s.log if role != "" { - logger = logger.WithValues("roleArn", role) + ctx = context.WithValue(ctx, "role", role) } for i, bucket := range bucketsToScan { - logger := logger.WithValues("bucket", bucket) + ctx := context.WithValue(ctx, "bucket", bucket) if common.IsDone(ctx) { return } s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "") - logger.V(3).Info("Scanning bucket") + ctx.Logger().V(3).Info("Scanning bucket") regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket) if err != nil { - logger.Error(err, "could not get regional client for bucket") + ctx.Logger().Error(err, "could not get regional client for bucket") continue } @@ -226,25 +232,29 @@ func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bu err = regionalClient.ListObjectsV2PagesWithContext( ctx, &s3.ListObjectsV2Input{Bucket: &bucket}, - func(page *s3.ListObjectsV2Output, last bool) bool { + func(page *s3.ListObjectsV2Output, _ bool) bool { s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount) return true }) if err != nil { if role == "" { - logger.Error(err, "could not list objects in bucket") + ctx.Logger().Error(err, "could not list objects in bucket") } else { // Our documentation blesses specifying a role to assume without specifying buckets to scan, which will // often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the // account, but the role probably doesn't have access to all of them). This makes it expected behavior // and therefore not an error. - logger.V(3).Info("could not list objects in bucket", - "err", err) + ctx.Logger().V(3).Info("could not list objects in bucket", "err", err) } } } - s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "") + s.SetProgressComplete( + len(bucketsToScan), + len(bucketsToScan), + fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), + "", + ) } // Chunks emits chunks of bytes over a channel. @@ -256,10 +266,15 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ . return s.visitRoles(ctx, visitor) } -func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionClient *s3.S3, role, bucket string) (*s3.S3, error) { +func (s *Source) getRegionalClientForBucket( + ctx context.Context, + defaultRegionClient *s3.S3, + role string, + bucket string, +) (*s3.S3, error) { region, err := s3manager.GetBucketRegionWithClient(ctx, defaultRegionClient, bucket) if err != nil { - return nil, errors.WrapPrefix(err, "could not get s3 region for bucket", 0) + return nil, fmt.Errorf("could not get s3 region for bucket: %s", bucket) } if region == defaultAWSRegion { @@ -268,7 +283,7 @@ func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionCl regionalClient, err := s.newClient(region, role) if err != nil { - return nil, errors.WrapPrefix(err, "could not create regional s3 client", 0) + return nil, fmt.Errorf("could not create regional s3 client for bucket %s: %w", bucket, err) } return regionalClient, nil @@ -448,7 +463,6 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr } _, err = regionalClient.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: &bucket}) - if err == nil { wasAbleToListAnyBucket = true } else if shouldHaveAccessToAllBuckets { @@ -458,7 +472,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr if !wasAbleToListAnyBucket { if roleArn == "" { - errs = append(errs, fmt.Errorf("could not list objects in any bucket")) + errs = append(errs, errors.New("could not list objects in any bucket")) } else { errs = append(errs, fmt.Errorf("role %q could not list objects in any bucket", roleArn)) } @@ -467,8 +481,16 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr return errs } -func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string)) error { - roles := s.conn.Roles +// visitRoles iterates over the configured AWS roles and calls the provided function +// for each role, passing in the default S3 client, the role ARN, and the list of +// buckets to scan. +// +// If no roles are configured, it will call the function with an empty role ARN. +func (s *Source) visitRoles( + ctx context.Context, + f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string), +) error { + roles := s.conn.GetRoles() if len(roles) == 0 { roles = []string{""} } @@ -476,7 +498,7 @@ func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defau for _, role := range roles { client, err := s.newClient(defaultAWSRegion, role) if err != nil { - return errors.WrapPrefix(err, "could not create s3 client", 0) + return fmt.Errorf("could not create s3 client: %w", err) } bucketsToScan, err := s.getBucketsToScan(client) @@ -493,7 +515,7 @@ func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defau // S3 links currently have the general format of: // https://[bucket].s3[.region unless us-east-1].amazonaws.com/[key] func makeS3Link(bucket, region, key string) string { - if region == "us-east-1" { + if region == defaultAWSRegion { region = "" } else { region = "." + region