From 4a4bb899924ea55634498fcc2c01a0b32fbf7ac9 Mon Sep 17 00:00:00 2001 From: Miccah Date: Fri, 9 Feb 2024 12:30:28 -0800 Subject: [PATCH] Add flag to write job reports to disk (#2298) * Add flag to write job reports to disk * Fix nil pointer / non-nil interface bug * Synchronize job report writer goroutine * Log when the report has been written --- main.go | 7 +++++ pkg/common/export_error.go | 16 ++++++++++ pkg/engine/engine.go | 64 ++++++++++++++++++++++++++++++++------ 3 files changed, 77 insertions(+), 10 deletions(-) create mode 100644 pkg/common/export_error.go diff --git a/main.go b/main.go index b34f9083d5d4..b9e7ba4ef139 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io" "net/http" _ "net/http/pprof" "os" @@ -64,6 +65,7 @@ var ( archiveTimeout = cli.Flag("archive-timeout", "Maximum time to spend extracting an archive.").Duration() includeDetectors = cli.Flag("include-detectors", "Comma separated list of detector types to include. Protobuf name or IDs may be used, as well as ranges.").Default("all").String() excludeDetectors = cli.Flag("exclude-detectors", "Comma separated list of detector types to exclude. Protobuf name or IDs may be used, as well as ranges. IDs defined here take precedence over the include list.").String() + jobReportFile = cli.Flag("output-report", "Write a scan report to the provided path.").Hidden().OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) gitScan = cli.Command("git", "Find credentials in git repositories.") gitScanURI = gitScan.Arg("uri", "Git repository URL. https://, file://, or ssh:// schema expected.").Required().String() @@ -398,6 +400,10 @@ func run(state overseer.State) { fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n") } + var jobReportWriter io.WriteCloser + if *jobReportFile != nil { + jobReportWriter = *jobReportFile + } e, err := engine.Start(ctx, engine.WithConcurrency(uint8(*concurrency)), engine.WithDecoders(decoders.DefaultDecoders()...), @@ -413,6 +419,7 @@ func run(state overseer.State) { engine.WithPrinter(printer), engine.WithFilterEntropy(*filterEntropy), engine.WithVerificationOverlap(*allowVerificationOverlap), + engine.WithJobReportWriter(jobReportWriter), ) if err != nil { logFatal(err, "error initializing engine") diff --git a/pkg/common/export_error.go b/pkg/common/export_error.go new file mode 100644 index 000000000000..68baab3e5604 --- /dev/null +++ b/pkg/common/export_error.go @@ -0,0 +1,16 @@ +package common + +// ExportError is an implementation of error that can be JSON marshalled. It +// must be a public exported type for this reason. +type ExportError string + +func (e ExportError) Error() string { return string(e) } + +// ExportErrors converts a list of errors into []ExportError. +func ExportErrors(errs ...error) []error { + output := make([]error, 0, len(errs)) + for _, err := range errs { + output = append(output, ExportError(err.Error())) + } + return output +} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 3d822e4bb2e2..31d813545853 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -2,8 +2,10 @@ package engine import ( "bytes" + "encoding/json" "errors" "fmt" + "io" "runtime" "sync" "sync/atomic" @@ -58,9 +60,10 @@ type Printer interface { type Engine struct { // CLI flags. - concurrency uint8 - decoders []decoders.Decoder - detectors []detectors.Detector + concurrency uint8 + decoders []decoders.Decoder + detectors []detectors.Detector + jobReportWriter io.WriteCloser // filterUnverified is used to reduce the number of unverified results. // If there are multiple unverified results for the same chunk for the same detector, // only the first one will be kept. @@ -119,6 +122,12 @@ func (r *verificationOverlapTracker) increment() { // Option is used to configure the engine during initialization using functional options. type Option func(*Engine) +func WithJobReportWriter(w io.WriteCloser) Option { + return func(e *Engine) { + e.jobReportWriter = w + } +} + func WithConcurrency(concurrency uint8) Option { return func(e *Engine) { e.concurrency = concurrency @@ -317,6 +326,7 @@ func Start(ctx context.Context, options ...Option) (*Engine, error) { if err := e.initialize(ctx, options...); err != nil { return nil, err } + e.initSourceManager(ctx) e.setDefaults(ctx) e.sanityChecks(ctx) e.startWorkers(ctx) @@ -373,6 +383,47 @@ func (e *Engine) initialize(ctx context.Context, options ...Option) error { return nil } +func (e *Engine) initSourceManager(ctx context.Context) { + opts := []func(*sources.SourceManager){ + sources.WithConcurrentSources(int(e.concurrency)), + sources.WithConcurrentUnits(int(e.concurrency)), + sources.WithSourceUnits(), + sources.WithBufferedOutput(defaultChannelBuffer), + } + if e.jobReportWriter != nil { + unitHook, finishedMetrics := sources.NewUnitHook(ctx) + opts = append(opts, sources.WithReportHook(unitHook)) + e.wgDetectorWorkers.Add(1) + go func() { + defer e.wgDetectorWorkers.Done() + defer func() { + e.jobReportWriter.Close() + // Add a bit of extra information if it's a *os.File. + if namer, ok := e.jobReportWriter.(interface{ Name() string }); ok { + ctx.Logger().Info("report written", "path", namer.Name()) + } else { + ctx.Logger().Info("report written") + } + }() + for metrics := range finishedMetrics { + metrics.Errors = common.ExportErrors(metrics.Errors...) + details, err := json.Marshal(map[string]any{ + "version": 1, + "data": metrics, + }) + if err != nil { + ctx.Logger().Error(err, "error marshalling job details") + continue + } + if _, err := e.jobReportWriter.Write(append(details, '\n')); err != nil { + ctx.Logger().Error(err, "error writing to file") + } + } + }() + } + e.sourceManager = sources.NewManager(opts...) +} + // setDefaults ensures that if specific engine properties aren't provided, // they're set to reasonable default values. It makes the engine robust to // incomplete configuration. @@ -384,13 +435,6 @@ func (e *Engine) setDefaults(ctx context.Context) { } ctx.Logger().V(3).Info("engine started", "workers", e.concurrency) - e.sourceManager = sources.NewManager( - sources.WithConcurrentSources(int(e.concurrency)), - sources.WithConcurrentUnits(int(e.concurrency)), - sources.WithSourceUnits(), - sources.WithBufferedOutput(defaultChannelBuffer), - ) - // Default decoders handle common encoding formats. if len(e.decoders) == 0 { e.decoders = decoders.DefaultDecoders()