diff --git a/pkg/engine/circleci.go b/pkg/engine/circleci.go index 2d588842bf41..5e74449ec3c6 100644 --- a/pkg/engine/circleci.go +++ b/pkg/engine/circleci.go @@ -1,13 +1,12 @@ package engine import ( - "runtime" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/circleci" ) @@ -29,10 +28,22 @@ func (e *Engine) ScanCircleCI(ctx context.Context, token string) error { sourceName := "trufflehog - Circle CI" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, circleci.SourceType) - circleSource := &circleci.Source{} - if err := circleSource.Init(ctx, "trufflehog - Circle CI", jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil { + src := &circleci.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err = e.sourceManager.Run(ctx, sourceName, circleSource) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/docker.go b/pkg/engine/docker.go index 38ffcc73ea34..86e2e4d9a3e4 100644 --- a/pkg/engine/docker.go +++ b/pkg/engine/docker.go @@ -1,11 +1,10 @@ package engine import ( - "runtime" - "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/context" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/docker" ) @@ -14,10 +13,22 @@ func (e *Engine) ScanDocker(ctx context.Context, conn *anypb.Any) error { sourceName := "trufflehog - docker" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, docker.SourceType) - dockerSource := 
&docker.Source{} - if err := dockerSource.Init(ctx, sourceName, jobID, sourceID, true, conn, runtime.NumCPU()); err != nil { + src := &docker.Source{} + err := src.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err := e.sourceManager.Run(ctx, sourceName, dockerSource) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/filesystem.go b/pkg/engine/filesystem.go index 169e52c31d33..f0a08d1da6e0 100644 --- a/pkg/engine/filesystem.go +++ b/pkg/engine/filesystem.go @@ -1,8 +1,6 @@ package engine import ( - "runtime" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -29,10 +27,22 @@ func (e *Engine) ScanFileSystem(ctx context.Context, c sources.FilesystemConfig) sourceName := "trufflehog - filesystem" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, filesystem.SourceType) - fileSystemSource := &filesystem.Source{} - if err := fileSystemSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil { + src := &filesystem.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err = e.sourceManager.Run(ctx, sourceName, fileSystemSource) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/gcs.go b/pkg/engine/gcs.go index f2bf82c98863..ea40699e0860 100644 --- a/pkg/engine/gcs.go +++ b/pkg/engine/gcs.go @@ -47,11 +47,23 @@ func (e *Engine) ScanGCS(ctx context.Context, c sources.GCSConfig) error { sourceName := "trufflehog - gcs" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, 
gcs.SourceType) - gcsSource := &gcs.Source{} - if err := gcsSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, int(c.Concurrency)); err != nil { + src := &gcs.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err = e.sourceManager.Run(ctx, sourceName, gcsSource) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/git.go b/pkg/engine/git.go index 411b2689e74c..4b13ebeec0c9 100644 --- a/pkg/engine/git.go +++ b/pkg/engine/git.go @@ -1,8 +1,6 @@ package engine import ( - "runtime" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -34,11 +32,22 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error { sourceName := "trufflehog - git" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, git.SourceType) - gitSource := &git.Source{} - if err := gitSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil { + src := &git.Source{} + err := src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err := e.sourceManager.Run(ctx, sourceName, gitSource) + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/github.go b/pkg/engine/github.go index 9bf84072a7d0..f5a71a1159d5 100644 --- a/pkg/engine/github.go +++ b/pkg/engine/github.go @@ -43,21 +43,33 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) error { return err } + sourceName := "trufflehog - github" + sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, github.SourceType) + + src 
:= &github.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { + return err + } + logOptions := &gogit.LogOptions{} opts := []git.ScanOption{ git.ScanOptionFilter(c.Filter), git.ScanOptionLogOptions(logOptions), } scanOptions := git.NewScanOptions(opts...) + src.WithScanOptions(scanOptions) - sourceName := "trufflehog - github" - sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, github.SourceType) - - githubSource := &github.Source{} - if err := githubSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil { - return err - } - githubSource.WithScanOptions(scanOptions) - _, err = e.sourceManager.Run(ctx, sourceName, githubSource) + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/gitlab.go b/pkg/engine/gitlab.go index 73885caed08e..712f0b4dfa8e 100644 --- a/pkg/engine/gitlab.go +++ b/pkg/engine/gitlab.go @@ -2,7 +2,6 @@ package engine import ( "fmt" - "runtime" gogit "github.com/go-git/go-git/v5" "google.golang.org/protobuf/proto" @@ -17,13 +16,6 @@ import ( // ScanGitLab scans GitLab with the provided configuration. func (e *Engine) ScanGitLab(ctx context.Context, c sources.GitlabConfig) error { - logOptions := &gogit.LogOptions{} - opts := []git.ScanOption{ - git.ScanOptionFilter(c.Filter), - git.ScanOptionLogOptions(logOptions), - } - scanOptions := git.NewScanOptions(opts...) 
- connection := &sourcespb.GitLab{SkipBinaries: c.SkipBinaries} switch { @@ -53,11 +45,30 @@ func (e *Engine) ScanGitLab(ctx context.Context, c sources.GitlabConfig) error { sourceName := "trufflehog - gitlab" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, gitlab.SourceType) - gitlabSource := &gitlab.Source{} - if err := gitlabSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil { + src := &gitlab.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - gitlabSource.WithScanOptions(scanOptions) - _, err = e.sourceManager.Run(ctx, sourceName, gitlabSource) + + logOptions := &gogit.LogOptions{} + opts := []git.ScanOption{ + git.ScanOptionFilter(c.Filter), + git.ScanOptionLogOptions(logOptions), + } + scanOptions := git.NewScanOptions(opts...) 
+ src.WithScanOptions(scanOptions) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/s3.go b/pkg/engine/s3.go index acb68bc8c097..8eac237dece6 100644 --- a/pkg/engine/s3.go +++ b/pkg/engine/s3.go @@ -2,7 +2,6 @@ package engine import ( "fmt" - "runtime" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -61,10 +60,22 @@ func (e *Engine) ScanS3(ctx context.Context, c sources.S3Config) error { sourceName := "trufflehog - s3" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, s3.SourceType) - s3Source := &s3.Source{} - if err := s3Source.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil { + src := &s3.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err = e.sourceManager.Run(ctx, sourceName, s3Source) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/syslog.go b/pkg/engine/syslog.go index 654318bbc4ce..f2fb7dcc7499 100644 --- a/pkg/engine/syslog.go +++ b/pkg/engine/syslog.go @@ -43,12 +43,23 @@ func (e *Engine) ScanSyslog(ctx context.Context, c sources.SyslogConfig) error { sourceName := "trufflehog - syslog" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, syslog.SourceType) - syslogSource := &syslog.Source{} - if err := syslogSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil { + + src := &syslog.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - syslogSource.InjectConnection(connection) - _, err = 
e.sourceManager.Run(ctx, sourceName, syslogSource) + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/engine/travisci.go b/pkg/engine/travisci.go index f1a002e95756..3d90e1609c9f 100644 --- a/pkg/engine/travisci.go +++ b/pkg/engine/travisci.go @@ -1,13 +1,12 @@ package engine import ( - "runtime" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/travisci" ) @@ -29,10 +28,22 @@ func (e *Engine) ScanTravisCI(ctx context.Context, token string) error { sourceName := "trufflehog - Travis CI" sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, travisci.SourceType) - travisSource := &travisci.Source{} - if err := travisSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil { + src := &travisci.Source{} + err = src.Init( + ctx, + sources.NewConfig( + &conn, + sources.WithName(sourceName), + sources.WithSourceID(sourceID), + sources.WithJobID(jobID), + sources.WithVerify(e.verify), + sources.WithConcurrency(int(e.concurrency)), + ), + ) + if err != nil { return err } - _, err = e.sourceManager.Run(ctx, sourceName, travisSource) + + _, err = e.sourceManager.Run(ctx, sourceName, src) return err } diff --git a/pkg/sources/circleci/circleci.go b/pkg/sources/circleci/circleci.go index def3145051ab..97ec0521e2e8 100644 --- a/pkg/sources/circleci/circleci.go +++ b/pkg/sources/circleci/circleci.go @@ -56,17 +56,17 @@ func (s *Source) JobID() sources.JobID { } // Init returns an initialized CircleCI source. 
-func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - s.name = name - s.sourceId = sourceId - s.jobId = jobId - s.verify = verify - s.jobPool = &errgroup.Group{} - s.jobPool.SetLimit(concurrency) +func (s *Source) Init(_ context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.verify = cfg.Verify + s.jobPool = new(errgroup.Group) + s.jobPool.SetLimit(cfg.Concurrency) s.client = common.RetryableHttpClientTimeout(3) var conn sourcespb.CircleCI - if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil { + if err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}); err != nil { return errors.WrapPrefix(err, "error unmarshalling connection", 0) } diff --git a/pkg/sources/circleci/circleci_test.go b/pkg/sources/circleci/circleci_test.go index cf7e21f55a5d..b4666f347b19 100644 --- a/pkg/sources/circleci/circleci_test.go +++ b/pkg/sources/circleci/circleci_test.go @@ -69,7 +69,15 @@ func TestSource_Scan(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 5) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithVerify(tt.init.verify), + sources.WithConcurrency(5), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/sources/docker/docker.go b/pkg/sources/docker/docker.go index 7088dcbceb57..f899d086dd56 100644 --- a/pkg/sources/docker/docker.go +++ b/pkg/sources/docker/docker.go @@ -56,18 +56,18 @@ func (s *Source) JobID() sources.JobID { } // Init initializes the source. 
-func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - s.name = name - s.sourceId = sourceId - s.jobId = jobId - s.verify = verify - s.concurrency = concurrency +func (s *Source) Init(_ context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.verify = cfg.Verify + s.concurrency = cfg.Concurrency // Reset metrics for this source at initialization time. dockerImagesScanned.WithLabelValues(s.name).Set(0) dockerLayersScanned.WithLabelValues(s.name).Set(0) - if err := anypb.UnmarshalTo(connection, &s.conn, proto.UnmarshalOptions{}); err != nil { + if err := anypb.UnmarshalTo(cfg.Connection, &s.conn, proto.UnmarshalOptions{}); err != nil { return fmt.Errorf("error unmarshalling connection: %w", err) } diff --git a/pkg/sources/docker/docker_test.go b/pkg/sources/docker/docker_test.go index aa290c3409da..2cce50fa3936 100644 --- a/pkg/sources/docker/docker_test.go +++ b/pkg/sources/docker/docker_test.go @@ -26,7 +26,14 @@ func TestDockerImageScan(t *testing.T) { assert.NoError(t, err) s := &Source{} - err = s.Init(context.TODO(), "test source", 0, 0, false, conn, 1) + err = s.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test source"), + sources.WithConcurrency(1), + ), + ) assert.NoError(t, err) var wg sync.WaitGroup @@ -63,7 +70,14 @@ func TestDockerImageScanWithDigest(t *testing.T) { assert.NoError(t, err) s := &Source{} - err = s.Init(context.TODO(), "test source", 0, 0, false, conn, 1) + err = s.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test source"), + sources.WithConcurrency(1), + ), + ) assert.NoError(t, err) var wg sync.WaitGroup diff --git a/pkg/sources/filesystem/filesystem.go b/pkg/sources/filesystem/filesystem.go index 1d3d0cc8b7d7..a253e6a9a957 100644 --- a/pkg/sources/filesystem/filesystem.go +++ 
b/pkg/sources/filesystem/filesystem.go @@ -59,17 +59,17 @@ func (s *Source) JobID() sources.JobID { } // Init returns an initialized Filesystem source. -func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { +func (s *Source) Init(aCtx context.Context, cfg *sources.Config) error { s.log = aCtx.Logger() - s.concurrency = concurrency - s.name = name - s.sourceId = sourceId - s.jobId = jobId - s.verify = verify + s.concurrency = cfg.Concurrency + s.name = cfg.Name + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.verify = cfg.Verify var conn sourcespb.Filesystem - if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil { + if err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}); err != nil { return errors.WrapPrefix(err, "error unmarshalling connection", 0) } s.paths = append(conn.Paths, conn.Directories...) diff --git a/pkg/sources/filesystem/filesystem_test.go b/pkg/sources/filesystem/filesystem_test.go index 4a846a6737cc..8e50fad0400c 100644 --- a/pkg/sources/filesystem/filesystem_test.go +++ b/pkg/sources/filesystem/filesystem_test.go @@ -63,7 +63,15 @@ func TestSource_Scan(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 5) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithVerify(tt.init.verify), + sources.WithConcurrency(5), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return @@ -162,7 +170,15 @@ func TestEnumerate(t *testing.T) { // Initialize the source. 
s := Source{} - err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test enumerate"), + sources.WithVerify(true), + sources.WithConcurrency(1), + ), + ) assert.NoError(t, err) reporter := sourcestest.TestReporter{} @@ -198,7 +214,15 @@ func TestChunkUnit(t *testing.T) { // Initialize the source. s := Source{} - err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test chunk unit"), + sources.WithVerify(true), + sources.WithConcurrency(1), + ), + ) assert.NoError(t, err) // Happy path single file. @@ -249,7 +273,15 @@ func TestEnumerateReporterErr(t *testing.T) { // Initialize the source. s := Source{} - err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test enumerate"), + sources.WithVerify(true), + sources.WithConcurrency(1), + ), + ) assert.NoError(t, err) // Enumerate should always return an error if the reporter returns an @@ -280,7 +312,15 @@ func TestChunkUnitReporterErr(t *testing.T) { // Initialize the source. s := Source{} - err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test chunk unit"), + sources.WithVerify(true), + sources.WithConcurrency(1), + ), + ) assert.NoError(t, err) // Happy path. ChunkUnit should always return an error if the reporter diff --git a/pkg/sources/gcs/gcs.go b/pkg/sources/gcs/gcs.go index a94ff4f2e1f2..07da1a3c2587 100644 --- a/pkg/sources/gcs/gcs.go +++ b/pkg/sources/gcs/gcs.go @@ -112,22 +112,22 @@ func (c *persistableCache) shouldPersist() (bool, string) { } // Init returns an initialized GCS source. 
-func (s *Source) Init(aCtx context.Context, name string, id sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { +func (s *Source) Init(aCtx context.Context, cfg *sources.Config) error { s.log = aCtx.Logger() - s.name = name - s.verify = verify - s.sourceId = sourceID - s.jobId = id - s.concurrency = concurrency + s.name = cfg.Name + s.verify = cfg.Verify + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.concurrency = cfg.Concurrency var conn sourcespb.GCS - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) + err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}) if err != nil { return errors.WrapPrefix(err, "error unmarshalling connection", 0) } - gcsManager, err := configureGCSManager(aCtx, &conn, concurrency) + gcsManager, err := configureGCSManager(aCtx, &conn, s.concurrency) if err != nil { return err } diff --git a/pkg/sources/gcs/gcs_integration_test.go b/pkg/sources/gcs/gcs_integration_test.go index 6c8fe2a09c46..8994a8190ad2 100644 --- a/pkg/sources/gcs/gcs_integration_test.go +++ b/pkg/sources/gcs/gcs_integration_test.go @@ -26,7 +26,15 @@ func TestChunks(t *testing.T) { ExcludeBuckets: []string{perfTestBucketGlob, publicBucket}, }) - err := source.Init(ctx, "test", 1, 1, true, conn, 8) + err := source.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test"), + sources.WithVerify(true), + sources.WithConcurrency(9), + ), + ) assert.Nil(t, err) chunksCh := make(chan *sources.Chunk, 1) @@ -66,7 +74,15 @@ func TestChunks_PublicBucket(t *testing.T) { IncludeBuckets: []string{publicBucket}, }) - err := source.Init(ctx, "test", 1, 1, true, conn, 8) + err := source.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test"), + sources.WithVerify(true), + sources.WithConcurrency(9), + ), + ) assert.Nil(t, err) chunksCh := make(chan *sources.Chunk, 1) diff --git a/pkg/sources/gcs/gcs_test.go 
b/pkg/sources/gcs/gcs_test.go index 126facfe40f5..0f6facf3af20 100644 --- a/pkg/sources/gcs/gcs_test.go +++ b/pkg/sources/gcs/gcs_test.go @@ -46,7 +46,15 @@ func TestSourceInit(t *testing.T) { Credential: &sourcespb.GCS_Unauthenticated{}, }) - err := source.Init(context.Background(), "test", 1, 1, true, conn, 8) + err := source.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test"), + sources.WithVerify(true), + sources.WithConcurrency(9), + ), + ) assert.Nil(t, err) assert.NotNil(t, source.gcsManager) } diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index d4577d28a302..32fc2fc3c844 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -145,17 +145,17 @@ func (s *Source) withScanOptions(scanOptions *ScanOptions) { } // Init returns an initialized GitHub source. -func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - s.name = name - s.sourceID = sourceId - s.jobID = jobId - s.verify = verify +func (s *Source) Init(aCtx context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceID = cfg.SourceID + s.jobID = cfg.JobID + s.verify = cfg.Verify if s.scanOptions == nil { s.scanOptions = &ScanOptions{} } var conn sourcespb.Git - if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil { + if err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}); err != nil { return fmt.Errorf("error unmarshalling connection: %w", err) } @@ -193,7 +193,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, so s.conn = &conn - if concurrency == 0 { + concurrency := cfg.Concurrency + if cfg.Concurrency == 0 { concurrency = runtime.NumCPU() } @@ -201,7 +202,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, so return err } - cfg := &Config{ + gitCfg := &Config{ SourceName: s.name, JobID: s.jobID, 
SourceID: s.sourceID, @@ -226,7 +227,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, so }, UseCustomContentWriter: s.useCustomContentWriter, } - s.git = NewGit(cfg) + s.git = NewGit(gitCfg) return nil } diff --git a/pkg/sources/git/git_test.go b/pkg/sources/git/git_test.go index cc9d31756d3d..25b4e6d5da57 100644 --- a/pkg/sources/git/git_test.go +++ b/pkg/sources/git/git_test.go @@ -131,7 +131,15 @@ func TestSource_Scan(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, tt.init.concurrency) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(tt.init.concurrency), + sources.WithVerify(tt.init.verify), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return @@ -226,7 +234,15 @@ func TestSource_Chunks_Integration(t *testing.T) { if err != nil { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(4), + sources.WithVerify(tt.init.verify), + ), + ) if err != nil { t.Fatal(err) } @@ -369,7 +385,15 @@ func TestSource_Chunks_Edge_Cases(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(4), + sources.WithVerify(tt.init.verify), + ), + ) if err != nil { t.Errorf("Source.Init() error = %v", err) return @@ -523,7 +547,15 @@ func TestEnumerate(t *testing.T) { // Initialize the source. 
s := Source{} - err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test enumerate"), + sources.WithConcurrency(1), + sources.WithVerify(true), + ), + ) assert.NoError(t, err) reporter := sourcestest.TestReporter{} @@ -552,7 +584,15 @@ func TestChunkUnit(t *testing.T) { Credential: &sourcespb.Git_Unauthenticated{}, }) assert.NoError(t, err) - err = s.Init(ctx, "test chunk", 0, 0, true, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test chunk"), + sources.WithConcurrency(1), + sources.WithVerify(true), + ), + ) assert.NoError(t, err) reporter := sourcestest.TestReporter{} diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 2891c6ef34fa..39cca3491ba4 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -208,21 +208,22 @@ func (c *filteredRepoCache) includeRepo(s string) bool { } // Init returns an initialized GitHub source. -func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { +func (s *Source) Init(aCtx context.Context, cfg *sources.Config) error { s.log = aCtx.Logger() - s.name = name - s.sourceID = sourceID - s.jobID = jobID - s.verify = verify - s.jobPool = &errgroup.Group{} + s.name = cfg.Name + s.sourceID = cfg.SourceID + s.jobID = cfg.JobID + s.verify = cfg.Verify + s.jobPool = new(errgroup.Group) + concurrency := int(cfg.Concurrency) s.jobPool.SetLimit(concurrency) s.httpClient = common.RetryableHttpClientTimeout(60) s.apiClient = github.NewClient(s.httpClient) var conn sourcespb.GitHub - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) + err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}) if err != nil { return fmt.Errorf("error unmarshalling connection: %w", err) } @@ -267,7 +268,7 @@ func (s *Source) Init(aCtx context.Context, name string, 
jobID sources.JobID, so s.publicMap = map[string]source_metadatapb.Visibility{} - cfg := &git.Config{ + gitCfg := &git.Config{ SourceName: s.name, JobID: s.jobID, SourceID: s.sourceID, @@ -294,7 +295,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so }, UseCustomContentWriter: s.useCustomContentWriter, } - s.git = git.NewGit(cfg) + s.git = git.NewGit(gitCfg) return nil } diff --git a/pkg/sources/github/github_integration_test.go b/pkg/sources/github/github_integration_test.go index d76b5b3f922c..7c7808122c84 100644 --- a/pkg/sources/github/github_integration_test.go +++ b/pkg/sources/github/github_integration_test.go @@ -180,7 +180,15 @@ func TestSource_ScanComments(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(4), + sources.WithVerify(tt.init.verify), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return @@ -275,7 +283,15 @@ func TestSource_ScanChunks(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(8), + sources.WithVerify(tt.init.verify), + ), + ) assert.Nil(t, err) chunksCh := make(chan *sources.Chunk, 1) @@ -596,7 +612,15 @@ func TestSource_Scan(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(4), + sources.WithVerify(tt.init.verify), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return @@ -737,7 +761,15 @@ func TestSource_paginateGists(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4) + err 
= s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(4), + sources.WithVerify(tt.init.verify), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return @@ -942,7 +974,15 @@ func TestSource_Chunks_TargetedScan(t *testing.T) { conn, err := anypb.New(tt.init.connection) assert.Nil(t, err) - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(8), + sources.WithVerify(tt.init.verify), + ), + ) assert.Nil(t, err) chunksCh := make(chan *sources.Chunk, 1) diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index 5c202fa9093c..66cd44f0ebce 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -41,9 +41,19 @@ func createTestSource(src *sourcespb.GitHub) (*Source, *anypb.Any) { func initTestSource(src *sourcespb.GitHub) *Source { s, conn := createTestSource(src) - if err := s.Init(context.Background(), "test - github", 0, 1337, false, conn, 1); err != nil { + + err := s.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test - github"), + sources.WithConcurrency(1), + ), + ) + if err != nil { panic(err) } + s.apiClient = github.NewClient(s.httpClient) gock.InterceptClient(s.httpClient) return s @@ -57,7 +67,14 @@ func TestInit(t *testing.T) { }, }) - err := source.Init(context.Background(), "test - github", 0, 1337, false, conn, 1) + err := source.Init( + context.Background(), + sources.NewConfig( + conn, + sources.WithName("test - github"), + sources.WithConcurrency(1), + ), + ) assert.Nil(t, err) // TODO: test error case diff --git a/pkg/sources/gitlab/gitlab.go b/pkg/sources/gitlab/gitlab.go index 62ff13db1823..c48e535e48bc 100644 --- a/pkg/sources/gitlab/gitlab.go +++ b/pkg/sources/gitlab/gitlab.go @@ -78,16 +78,17 @@ func (s *Source) JobID() 
sources.JobID { } // Init returns an initialized Gitlab source. -func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - s.name = name - s.sourceID = sourceId - s.jobID = jobId - s.verify = verify - s.jobPool = &errgroup.Group{} +func (s *Source) Init(_ context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceID = cfg.SourceID + s.jobID = cfg.JobID + s.verify = cfg.Verify + s.jobPool = new(errgroup.Group) + concurrency := cfg.Concurrency s.jobPool.SetLimit(concurrency) var conn sourcespb.GitLab - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) + err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}) if err != nil { return fmt.Errorf("error unmarshalling connection: %w", err) } @@ -114,7 +115,7 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc // We may need the password as a token if the user is using an access_token with basic auth. 
s.token = cred.BasicAuth.Password default: - return fmt.Errorf("invalid configuration given for source %q (%s)", name, s.Type().String()) + return fmt.Errorf("invalid configuration given for source %q (%s)", s.name, s.Type().String()) } if len(s.url) == 0 { @@ -127,7 +128,7 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc return err } - cfg := &git.Config{ + gitCfg := &git.Config{ SourceName: s.name, JobID: s.jobID, SourceID: s.sourceID, @@ -153,7 +154,7 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc }, UseCustomContentWriter: s.useCustomContentWriter, } - s.git = git.NewGit(cfg) + s.git = git.NewGit(gitCfg) return nil } diff --git a/pkg/sources/gitlab/gitlab_test.go b/pkg/sources/gitlab/gitlab_test.go index 78b6a5ec4ed6..4a7ae197e9d1 100644 --- a/pkg/sources/gitlab/gitlab_test.go +++ b/pkg/sources/gitlab/gitlab_test.go @@ -144,7 +144,15 @@ func TestSource_Scan(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithConcurrency(10), + sources.WithVerify(tt.init.verify), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return @@ -309,7 +317,14 @@ func TestSource_Validate(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.name, 0, 0, false, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.name), + sources.WithConcurrency(1), + ), + ) if err != nil { t.Fatalf("Source.Init() error: %v", err) } diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index 8ac45b1e66ba..e5ef0e90e87c 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -73,21 +73,21 @@ func (s *Source) JobID() sources.JobID { } // Init returns an initialized AWS source -func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, 
connection *anypb.Any, concurrency int) error { - s.log = context.WithValues(aCtx, "source", s.Type(), "name", name).Logger() - - s.name = name - s.sourceId = sourceId - s.jobId = jobId - s.verify = verify - s.concurrency = concurrency +func (s *Source) Init(aCtx context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.verify = cfg.Verify + s.concurrency = cfg.Concurrency s.errorCount = &sync.Map{} s.log = aCtx.Logger() - s.jobPool = &errgroup.Group{} - s.jobPool.SetLimit(concurrency) + s.jobPool = new(errgroup.Group) + s.jobPool.SetLimit(s.concurrency) + + s.log = context.WithValues(aCtx, "source", s.Type(), "name", s.name).Logger() var conn sourcespb.S3 - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) + err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}) if err != nil { return errors.WrapPrefix(err, "error unmarshalling connection", 0) } diff --git a/pkg/sources/s3/s3_integration_test.go b/pkg/sources/s3/s3_integration_test.go index 12d72376c9c6..6d50f8b17baa 100644 --- a/pkg/sources/s3/s3_integration_test.go +++ b/pkg/sources/s3/s3_integration_test.go @@ -10,9 +10,10 @@ import ( "time" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" - "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" @@ -33,7 +34,14 @@ func TestSource_ChunksCount(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, "test name", 0, 0, false, conn, 1) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName("test name"), + sources.WithConcurrency(1), + ), + ) chunksCh := make(chan *sources.Chunk) go func() { defer close(chunksCh) @@ -159,7 +167,14 @@ func TestSource_Validate(t *testing.T) { t.Fatal(err) } - err = 
s.Init(ctx, tt.name, 0, 0, false, conn, 0) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.name), + sources.WithConcurrency(1), + ), + ) if err != nil { t.Fatal(err) } diff --git a/pkg/sources/s3/s3_test.go b/pkg/sources/s3/s3_test.go index c4e3fc75f399..cf2c4605ae68 100644 --- a/pkg/sources/s3/s3_test.go +++ b/pkg/sources/s3/s3_test.go @@ -10,12 +10,13 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" - "google.golang.org/protobuf/types/known/anypb" ) func TestSource_Chunks(t *testing.T) { @@ -31,7 +32,6 @@ func TestSource_Chunks(t *testing.T) { s3secret := secret.MustGetField("AWS_S3_SECRET") type init struct { - name string verify bool connection *sourcespb.S3 setEnv map[string]string @@ -92,7 +92,15 @@ func TestSource_Chunks(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.name), + sources.WithConcurrency(8), + sources.WithVerify(tt.init.verify), + ), + ) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/sources/source_manager_test.go b/pkg/sources/source_manager_test.go index 815f0507e5bc..f9617014af2e 100644 --- a/pkg/sources/source_manager_test.go +++ b/pkg/sources/source_manager_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" @@ -25,9 +24,9 @@ type DummySource struct { func (d 
*DummySource) Type() sourcespb.SourceType { return 1337 } func (d *DummySource) SourceID() SourceID { return d.sourceID } func (d *DummySource) JobID() JobID { return d.jobID } -func (d *DummySource) Init(_ context.Context, _ string, jobID JobID, sourceID SourceID, _ bool, _ *anypb.Any, _ int) error { - d.sourceID = sourceID - d.jobID = jobID +func (d *DummySource) Init(_ context.Context, cfg *Config) error { + d.sourceID = cfg.SourceID + d.jobID = cfg.JobID return nil } func (d *DummySource) GetProgress() *Progress { return nil } @@ -85,7 +84,18 @@ func (c errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) erro // buildDummy is a helper function to enroll a DummySource with a SourceManager. func buildDummy(chunkMethod chunker) (Source, error) { source := &DummySource{chunker: chunkMethod} - if err := source.Init(context.Background(), "dummy", 123, 456, true, nil, 42); err != nil { + err := source.Init( + context.Background(), + NewConfig( + nil, + WithName("dummy"), + WithSourceID(456), + WithJobID(123), + WithVerify(true), + WithConcurrency(42), + ), + ) + if err != nil { return nil, err } return source, nil diff --git a/pkg/sources/sources.go b/pkg/sources/sources.go index 6d895887101a..6993c60fb5a7 100644 --- a/pkg/sources/sources.go +++ b/pkg/sources/sources.go @@ -1,6 +1,7 @@ package sources import ( + "runtime" "sync" "google.golang.org/protobuf/types/known/anypb" @@ -57,6 +58,66 @@ type ChunkingTarget struct { SecretID int64 } +// SourceConfigOption is responsible for configuring a Config. +// This allows for a more flexible way to configure a Source. +type SourceConfigOption func(*Config) + +// WithName sets the name for the source. +func WithName(name string) SourceConfigOption { + return func(c *Config) { c.Name = name } +} + +// WithJobID sets the job ID for the source. +func WithJobID(jobID JobID) SourceConfigOption { + return func(c *Config) { c.JobID = jobID } +} + +// WithSourceID sets the source ID for the source. 
+func WithSourceID(sourceID SourceID) SourceConfigOption { + return func(c *Config) { c.SourceID = sourceID } +} + +// WithConcurrency sets the number of concurrent workers for the source; NewConfig replaces a value of 0 with runtime.NumCPU(). +func WithConcurrency(concurrency int) SourceConfigOption { + return func(c *Config) { c.Concurrency = concurrency } +} + +// WithVerify sets the verify flag for the source. +func WithVerify(verify bool) SourceConfigOption { + return func(c *Config) { c.Verify = verify } +} + +// Config provides a way to configure a Source. +type Config struct { + // Name is the name of the source. + Name string + // JobID is the ID of the job that the source is associated with. + JobID JobID + // SourceID is the ID of the source. + SourceID SourceID + // Verify specifies whether any secrets in the Chunk should be verified. + Verify bool + // Connection is the connection information for the source. + Connection *anypb.Any + // Concurrency is the number of concurrent workers to use for the source; NewConfig defaults a zero value to runtime.NumCPU(). + Concurrency int +} + +// NewConfig creates a new Config with the given connection and options. +// The connection is the only required parameter because a source cannot be initialized without it. Options are applied in the order given; if Concurrency is still 0 afterwards it defaults to runtime.NumCPU(). +func NewConfig(connection *anypb.Any, opts ...SourceConfigOption) *Config { + cfg := &Config{Connection: connection} + + for _, opt := range opts { + opt(cfg) + } + if cfg.Concurrency == 0 { + cfg.Concurrency = runtime.NumCPU() + } + + return cfg +} + // Source defines the interface required to implement a source chunker. type Source interface { // Type returns the source type, used for matching against configuration and jobs. @@ -66,7 +127,7 @@ type Source interface { // JobID returns the initialized job ID used for tracking relationships in the DB. JobID() JobID // Init initializes the source. 
- Init(aCtx context.Context, name string, jobId JobID, sourceId SourceID, verify bool, connection *anypb.Any, concurrency int) error + Init(aCtx context.Context, cfg *Config) error // Chunks emits data over a channel which is then decoded and scanned for secrets. // By default, data is obtained indiscriminately. However, by providing one or more // ChunkingTarget parameters, the caller can direct the function to retrieve diff --git a/pkg/sources/syslog/syslog.go b/pkg/sources/syslog/syslog.go index 48b607edd09a..11e73cff43e0 100644 --- a/pkg/sources/syslog/syslog.go +++ b/pkg/sources/syslog/syslog.go @@ -123,15 +123,14 @@ func (s *Source) InjectConnection(conn *sourcespb.Syslog) { } // Init returns an initialized Syslog source. -func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - - s.name = name - s.sourceId = sourceId - s.jobId = jobId - s.verify = verify +func (s *Source) Init(_ context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.verify = cfg.Verify var conn sourcespb.Syslog - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) + err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}) if err != nil { return errors.WrapPrefix(err, "error unmarshalling connection", 0) } diff --git a/pkg/sources/travisci/travisci.go b/pkg/sources/travisci/travisci.go index cfe457dbc1cc..99b8fa6b6916 100644 --- a/pkg/sources/travisci/travisci.go +++ b/pkg/sources/travisci/travisci.go @@ -56,16 +56,16 @@ func (s *Source) JobID() sources.JobID { } // Init returns an initialized TravisCI source. 
-func (s *Source) Init(ctx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { - s.name = name - s.sourceId = sourceId - s.jobId = jobId - s.verify = verify - s.jobPool = &errgroup.Group{} - s.jobPool.SetLimit(concurrency) +func (s *Source) Init(ctx context.Context, cfg *sources.Config) error { + s.name = cfg.Name + s.sourceId = cfg.SourceID + s.jobId = cfg.JobID + s.verify = cfg.Verify + s.jobPool = new(errgroup.Group) + s.jobPool.SetLimit(cfg.Concurrency) var conn sourcespb.TravisCI - if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil { + if err := anypb.UnmarshalTo(cfg.Connection, &conn, proto.UnmarshalOptions{}); err != nil { return errors.WrapPrefix(err, "error unmarshalling connection", 0) } diff --git a/pkg/sources/travisci/travisci_test.go b/pkg/sources/travisci/travisci_test.go index d00f9663a46f..93ebf894a1dd 100644 --- a/pkg/sources/travisci/travisci_test.go +++ b/pkg/sources/travisci/travisci_test.go @@ -11,6 +11,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest" ) @@ -69,7 +70,15 @@ func TestSource_Scan(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 5) + err = s.Init( + ctx, + sources.NewConfig( + conn, + sources.WithName(tt.init.name), + sources.WithVerify(tt.init.verify), + sources.WithConcurrency(5), + ), + ) if (err != nil) != tt.wantErr { t.Fatalf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) }