[feat] - S3 metrics (#3577)
* add config option for s3 resumption

* updates

* initial progress tracking logic

* more testing

* revert s3 source file

* UpdateScanProgress tests

* adjust

* updates

* invert

* updates

* updates

* fix

* update

* adjust test

* fix

* remove progress tracking

* cleanup

* cleanup

* remove dupe

* add metrics to s3 scan

* make collector a singleton

* address comments

* fix

* remove
ahrav authored Nov 26, 2024
1 parent 33879e4 commit 7b3d98d
Showing 3 changed files with 156 additions and 5 deletions.
99 changes: 99 additions & 0 deletions pkg/sources/s3/metrics.go
@@ -0,0 +1,99 @@
package s3

import (
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"

    "github.com/trufflesecurity/trufflehog/v3/pkg/common"
)

// metricsCollector defines the interface for recording S3 scan metrics.
type metricsCollector interface {
    // Object metrics.

    RecordObjectScanned(bucket string)
    RecordObjectSkipped(bucket, reason string)
    RecordObjectError(bucket string)

    // Role metrics.

    RecordRoleScanned(roleArn string)
    RecordBucketForRole(roleArn string)
}

type collector struct {
    objectsScanned *prometheus.CounterVec
    objectsSkipped *prometheus.CounterVec
    objectsErrors  *prometheus.CounterVec
    rolesScanned   *prometheus.GaugeVec
    bucketsPerRole *prometheus.GaugeVec
}

var metricsInstance metricsCollector

func init() {
    metricsInstance = &collector{
        objectsScanned: promauto.NewCounterVec(prometheus.CounterOpts{
            Namespace: common.MetricsNamespace,
            Subsystem: common.MetricsSubsystem,
            Name:      "objects_scanned_total",
            Help:      "Total number of S3 objects successfully scanned",
        }, []string{"bucket"}),

        objectsSkipped: promauto.NewCounterVec(prometheus.CounterOpts{
            Namespace: common.MetricsNamespace,
            Subsystem: common.MetricsSubsystem,
            Name:      "objects_skipped_total",
            Help:      "Total number of S3 objects skipped during scan",
        }, []string{"bucket", "reason"}),

        objectsErrors: promauto.NewCounterVec(prometheus.CounterOpts{
            Namespace: common.MetricsNamespace,
            Subsystem: common.MetricsSubsystem,
            Name:      "objects_errors_total",
            Help:      "Total number of errors encountered during S3 scan",
        }, []string{"bucket"}),

        rolesScanned: promauto.NewGaugeVec(prometheus.GaugeOpts{
            Namespace: common.MetricsNamespace,
            Subsystem: common.MetricsSubsystem,
            Name:      "roles_scanned",
            Help:      "Number of AWS roles being scanned",
        }, []string{"role_arn"}),

        bucketsPerRole: promauto.NewGaugeVec(prometheus.GaugeOpts{
            Namespace: common.MetricsNamespace,
            Subsystem: common.MetricsSubsystem,
            Name:      "buckets_per_role",
            Help:      "Number of buckets accessible per AWS role",
        }, []string{"role_arn"}),
    }
}

func (c *collector) RecordObjectScanned(bucket string) {
    c.objectsScanned.WithLabelValues(bucket).Inc()
}

func (c *collector) RecordObjectSkipped(bucket, reason string) {
    c.objectsSkipped.WithLabelValues(bucket, reason).Inc()
}

func (c *collector) RecordObjectError(bucket string) {
    c.objectsErrors.WithLabelValues(bucket).Inc()
}

const defaultRoleARN = "default"

func (c *collector) RecordRoleScanned(roleArn string) {
    if roleArn == "" {
        roleArn = defaultRoleARN
    }
    c.rolesScanned.WithLabelValues(roleArn).Set(1)
}

func (c *collector) RecordBucketForRole(roleArn string) {
    if roleArn == "" {
        roleArn = defaultRoleARN
    }
    c.bucketsPerRole.WithLabelValues(roleArn).Inc()
}
31 changes: 26 additions & 5 deletions pkg/sources/s3/s3.go
@@ -48,6 +48,7 @@ type Source struct {

    checkpointer *Checkpointer
    sources.Progress
    metricsCollector metricsCollector

    errorCount *sync.Map
    jobPool *errgroup.Group
@@ -94,6 +95,7 @@ func (s *Source) Init(
    s.conn = &conn

    s.checkpointer = NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress)
    s.metricsCollector = metricsInstance

    s.setMaxObjectSize(conn.GetMaxObjectSize())

@@ -106,11 +108,12 @@

func (s *Source) Validate(ctx context.Context) []error {
    var errs []error
    visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
    visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
        roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
        if len(roleErrs) > 0 {
            errs = append(errs, roleErrs...)
        }
        return nil
    }

    if err := s.visitRoles(ctx, visitor); err != nil {
@@ -307,6 +310,7 @@ func (s *Source) scanBuckets(

    bucketsToScanCount := len(bucketsToScan)
    for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
        s.metricsCollector.RecordBucketForRole(role)
        bucket := bucketsToScan[bucketIdx]
        ctx := context.WithValue(ctx, "bucket", bucket)

@@ -385,8 +389,9 @@ func (s *Source) scanBuckets(

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
    visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
    visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
        s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
        return nil
    }

    return s.visitRoles(ctx, visitor)
@@ -427,6 +432,7 @@ func (s *Source) pageChunker(

    for objIdx, obj := range metadata.page.Contents {
        if obj == nil {
            s.metricsCollector.RecordObjectSkipped(metadata.bucket, "nil_object")
            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
                ctx.Logger().Error(err, "could not update progress for nil object")
            }
@@ -442,6 +448,7 @@
        // Skip GLACIER and GLACIER_IR objects.
        if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
            ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
            s.metricsCollector.RecordObjectSkipped(metadata.bucket, "storage_class")
            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
                ctx.Logger().Error(err, "could not update progress for glacier object")
            }
@@ -451,6 +458,7 @@
        // Ignore large files.
        if *obj.Size > s.maxObjectSize {
            ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
            s.metricsCollector.RecordObjectSkipped(metadata.bucket, "size_limit")
            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
                ctx.Logger().Error(err, "could not update progress for large file")
            }
@@ -460,6 +468,7 @@
        // Skip empty files.
        if *obj.Size == 0 {
            ctx.Logger().V(5).Info("Skipping empty file")
            s.metricsCollector.RecordObjectSkipped(metadata.bucket, "empty_file")
            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
                ctx.Logger().Error(err, "could not update progress for empty file")
            }
@@ -469,6 +478,7 @@
        // Skip incompatible extensions.
        if common.SkipFile(*obj.Key) {
            ctx.Logger().V(5).Info("Skipping file with incompatible extension")
            s.metricsCollector.RecordObjectSkipped(metadata.bucket, "incompatible_extension")
            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
                ctx.Logger().Error(err, "could not update progress for incompatible file")
            }
@@ -483,6 +493,7 @@

        if strings.HasSuffix(*obj.Key, "/") {
            ctx.Logger().V(5).Info("Skipping directory")
            s.metricsCollector.RecordObjectSkipped(metadata.bucket, "directory")
            return nil
        }

@@ -508,8 +519,12 @@
            Key: obj.Key,
        })
        if err != nil {
            if !strings.Contains(err.Error(), "AccessDenied") {
            if strings.Contains(err.Error(), "AccessDenied") {
                ctx.Logger().Error(err, "could not get S3 object; access denied")
                s.metricsCollector.RecordObjectSkipped(metadata.bucket, "access_denied")
            } else {
                ctx.Logger().Error(err, "could not get S3 object")
                s.metricsCollector.RecordObjectError(metadata.bucket)
            }
            // According to the documentation for GetObjectWithContext,
            // the response can be non-nil even if there was an error.
@@ -563,6 +578,7 @@

        if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
            ctx.Logger().Error(err, "error handling file")
            s.metricsCollector.RecordObjectError(metadata.bucket)
            return nil
        }

@@ -580,6 +596,7 @@
        if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
            ctx.Logger().Error(err, "could not update progress for scanned object")
        }
        s.metricsCollector.RecordObjectScanned(metadata.bucket)

        return nil
    })
@@ -633,14 +650,16 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
// If no roles are configured, it will call the function with an empty role ARN.
func (s *Source) visitRoles(
    ctx context.Context,
    f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
    f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
) error {
    roles := s.conn.GetRoles()
    if len(roles) == 0 {
        roles = []string{""}
    }

    for _, role := range roles {
        s.metricsCollector.RecordRoleScanned(role)

        client, err := s.newClient(defaultAWSRegion, role)
        if err != nil {
            return fmt.Errorf("could not create s3 client: %w", err)
@@ -651,7 +670,9 @@
            return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
        }

        f(ctx, client, role, bucketsToScan)
        if err := f(ctx, client, role, bucketsToScan); err != nil {
            return err
        }
    }

    return nil
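Because Source stores the collector behind the metricsCollector interface (Init assigns the package-level metricsInstance), a test can substitute its own implementation and avoid touching the shared Prometheus registry. No such stub ships with this commit; the noopCollector below is a hypothetical sketch of what one could look like.

// noopCollector is a hypothetical test double satisfying metricsCollector;
// it is not part of this commit.
type noopCollector struct{}

func (noopCollector) RecordObjectScanned(bucket string)         {}
func (noopCollector) RecordObjectSkipped(bucket, reason string) {}
func (noopCollector) RecordObjectError(bucket string)           {}
func (noopCollector) RecordRoleScanned(roleArn string)          {}
func (noopCollector) RecordBucketForRole(roleArn string)        {}

// A unit test could then override the field after Init:
//
//     s := &Source{}
//     _ = s.Init(ctx, "test", 0, 0, false, conn, 1)
//     s.metricsCollector = noopCollector{}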
31 changes: 31 additions & 0 deletions pkg/sources/s3/s3_integration_test.go
@@ -82,6 +82,37 @@ func TestSource_ChunksLarge(t *testing.T) {
    assert.Equal(t, got, wantChunkCount)
}

func TestSourceChunksNoResumption(t *testing.T) {
    ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
    defer cancel()

    s := Source{}
    connection := &sourcespb.S3{
        Credential: &sourcespb.S3_Unauthenticated{},
        Buckets: []string{"trufflesec-ahrav-test-2"},
    }
    conn, err := anypb.New(connection)
    if err != nil {
        t.Fatal(err)
    }

    err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
    chunksCh := make(chan *sources.Chunk)
    go func() {
        defer close(chunksCh)
        err = s.Chunks(ctx, chunksCh)
        assert.Nil(t, err)
    }()

    wantChunkCount := 19787
    got := 0

    for range chunksCh {
        got++
    }
    assert.Equal(t, got, wantChunkCount)
}

func TestSource_Validate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
