Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chore] - minor cleanup S3 source #3554

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 77 additions & 55 deletions pkg/sources/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ import (
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/go-errors/errors"
"github.com/go-logr/logr"
"github.com/trufflesecurity/trufflehog/v3/pkg/log"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
"github.com/trufflesecurity/trufflehog/v3/pkg/log"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer"
Expand All @@ -40,16 +39,17 @@ const (

type Source struct {
name string
sourceId sources.SourceID
jobId sources.JobID
sourceID sources.SourceID
jobID sources.JobID
verify bool
concurrency int
log logr.Logger
sources.Progress
conn *sourcespb.S3

errorCount *sync.Map
conn *sourcespb.S3
jobPool *errgroup.Group
maxObjectSize int64

sources.CommonSourceUnitUnmarshaller
}

Expand All @@ -59,43 +59,41 @@ var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.Validator = (*Source)(nil)

// Type returns the type of source
func (s *Source) Type() sourcespb.SourceType {
return SourceType
}
func (s *Source) Type() sourcespb.SourceType { return SourceType }

func (s *Source) SourceID() sources.SourceID {
return s.sourceId
}
func (s *Source) SourceID() sources.SourceID { return s.sourceID }

func (s *Source) JobID() sources.JobID {
return s.jobId
}
func (s *Source) JobID() sources.JobID { return s.jobID }

// Init initializes the S3 source from the provided connection configuration.
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
s.log = context.WithValues(aCtx, "source", s.Type(), "name", name).Logger()

func (s *Source) Init(
_ context.Context,
name string,
jobID sources.JobID,
sourceID sources.SourceID,
verify bool,
connection *anypb.Any,
concurrency int,
) error {
s.name = name
s.sourceId = sourceId
s.jobId = jobId
s.sourceID = sourceID
s.jobID = jobID
s.verify = verify
s.concurrency = concurrency
s.errorCount = &sync.Map{}
s.log = aCtx.Logger()
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)

var conn sourcespb.S3
err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
if err != nil {
return errors.WrapPrefix(err, "error unmarshalling connection", 0)
if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil {
return fmt.Errorf("error unmarshalling connection: %w", err)
}
s.conn = &conn

s.setMaxObjectSize(conn.GetMaxObjectSize())

if len(conn.Buckets) > 0 && len(conn.IgnoreBuckets) > 0 {
return fmt.Errorf("either a bucket include list or a bucket ignore list can be specified, but not both")
if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 {
return errors.New("either a bucket include list or a bucket ignore list can be specified, but not both")
}

return nil
Expand All @@ -110,8 +108,7 @@ func (s *Source) Validate(ctx context.Context) []error {
}
}

err := s.visitRoles(ctx, visitor)
if err != nil {
if err := s.visitRoles(ctx, visitor); err != nil {
errs = append(errs, err)
}

Expand All @@ -136,11 +133,15 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {

switch cred := s.conn.GetCredential().(type) {
case *sourcespb.S3_SessionToken:
cfg.Credentials = credentials.NewStaticCredentials(cred.SessionToken.Key, cred.SessionToken.Secret, cred.SessionToken.SessionToken)
cfg.Credentials = credentials.NewStaticCredentials(
cred.SessionToken.GetKey(),
cred.SessionToken.GetSecret(),
cred.SessionToken.GetSessionToken(),
)
log.RedactGlobally(cred.SessionToken.GetSecret())
log.RedactGlobally(cred.SessionToken.GetSessionToken())
case *sourcespb.S3_AccessKey:
cfg.Credentials = credentials.NewStaticCredentials(cred.AccessKey.Key, cred.AccessKey.Secret, "")
cfg.Credentials = credentials.NewStaticCredentials(cred.AccessKey.GetKey(), cred.AccessKey.GetSecret(), "")
log.RedactGlobally(cred.AccessKey.GetSecret())
case *sourcespb.S3_Unauthenticated:
cfg.Credentials = credentials.AnonymousCredentials
Expand Down Expand Up @@ -174,12 +175,12 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {

// IAM identity needs the s3:ListAllMyBuckets permission (used by the ListBuckets API).
func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
if len(s.conn.Buckets) > 0 {
return s.conn.Buckets, nil
if buckets := s.conn.GetBuckets(); len(buckets) > 0 {
return buckets, nil
}

ignore := make(map[string]struct{}, len(s.conn.IgnoreBuckets))
for _, bucket := range s.conn.IgnoreBuckets {
ignore := make(map[string]struct{}, len(s.conn.GetIgnoreBuckets()))
for _, bucket := range s.conn.GetIgnoreBuckets() {
ignore[bucket] = struct{}{}
}

Expand All @@ -198,53 +199,62 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
return bucketsToScan, nil
}

func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) {
objectCount := uint64(0)
func (s *Source) scanBuckets(
ctx context.Context,
client *s3.S3,
role string,
bucketsToScan []string,
chunksChan chan *sources.Chunk,
) {
var objectCount uint64

logger := s.log
if role != "" {
logger = logger.WithValues("roleArn", role)
ctx = context.WithValue(ctx, "role", role)
}

for i, bucket := range bucketsToScan {
logger := logger.WithValues("bucket", bucket)
ctx := context.WithValue(ctx, "bucket", bucket)

if common.IsDone(ctx) {
return
}

s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
logger.V(3).Info("Scanning bucket")
ctx.Logger().V(3).Info("Scanning bucket")

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
logger.Error(err, "could not get regional client for bucket")
ctx.Logger().Error(err, "could not get regional client for bucket")
continue
}

errorCount := sync.Map{}

err = regionalClient.ListObjectsV2PagesWithContext(
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
func(page *s3.ListObjectsV2Output, last bool) bool {
func(page *s3.ListObjectsV2Output, _ bool) bool {
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
return true
})

if err != nil {
if role == "" {
logger.Error(err, "could not list objects in bucket")
ctx.Logger().Error(err, "could not list objects in bucket")
} else {
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
// and therefore not an error.
logger.V(3).Info("could not list objects in bucket",
"err", err)
ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
}
}
}
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")
s.SetProgressComplete(
len(bucketsToScan),
len(bucketsToScan),
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
"",
)
}

// Chunks emits chunks of bytes over a channel.
Expand All @@ -256,10 +266,15 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
return s.visitRoles(ctx, visitor)
}

func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionClient *s3.S3, role, bucket string) (*s3.S3, error) {
func (s *Source) getRegionalClientForBucket(
ctx context.Context,
defaultRegionClient *s3.S3,
role string,
bucket string,
) (*s3.S3, error) {
region, err := s3manager.GetBucketRegionWithClient(ctx, defaultRegionClient, bucket)
if err != nil {
return nil, errors.WrapPrefix(err, "could not get s3 region for bucket", 0)
return nil, fmt.Errorf("could not get s3 region for bucket: %s", bucket)
}

if region == defaultAWSRegion {
Expand All @@ -268,7 +283,7 @@ func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionCl

regionalClient, err := s.newClient(region, role)
if err != nil {
return nil, errors.WrapPrefix(err, "could not create regional s3 client", 0)
return nil, fmt.Errorf("could not create regional s3 client for bucket %s: %w", bucket, err)
}

return regionalClient, nil
Expand Down Expand Up @@ -448,7 +463,6 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
}

_, err = regionalClient.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: &bucket})

if err == nil {
wasAbleToListAnyBucket = true
} else if shouldHaveAccessToAllBuckets {
Expand All @@ -458,7 +472,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr

if !wasAbleToListAnyBucket {
if roleArn == "" {
errs = append(errs, fmt.Errorf("could not list objects in any bucket"))
errs = append(errs, errors.New("could not list objects in any bucket"))
} else {
errs = append(errs, fmt.Errorf("role %q could not list objects in any bucket", roleArn))
}
Expand All @@ -467,16 +481,24 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
return errs
}

func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string)) error {
roles := s.conn.Roles
// visitRoles iterates over the configured AWS roles and calls the provided function
// for each role, passing in the default S3 client, the role ARN, and the list of
// buckets to scan.
//
// If no roles are configured, it will call the function with an empty role ARN.
func (s *Source) visitRoles(
ctx context.Context,
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
) error {
roles := s.conn.GetRoles()
if len(roles) == 0 {
roles = []string{""}
}

for _, role := range roles {
client, err := s.newClient(defaultAWSRegion, role)
if err != nil {
return errors.WrapPrefix(err, "could not create s3 client", 0)
return fmt.Errorf("could not create s3 client: %w", err)
}

bucketsToScan, err := s.getBucketsToScan(client)
Expand All @@ -493,7 +515,7 @@ func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defau
// S3 links currently have the general format of:
// https://[bucket].s3[.region unless us-east-1].amazonaws.com/[key]
func makeS3Link(bucket, region, key string) string {
if region == "us-east-1" {
if region == defaultAWSRegion {
region = ""
} else {
region = "." + region
Expand Down
Loading