Skip to content

Commit

Permalink
Add flag to write job reports to disk (#2298)
Browse files Browse the repository at this point in the history
* Add flag to write job reports to disk

* Fix nil pointer / non-nil interface bug

* Synchronize job report writer goroutine

* Log when the report has been written
  • Loading branch information
mcastorina authored and ahrav committed Feb 11, 2024
1 parent 698a20b commit 4a4bb89
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 10 deletions.
7 changes: 7 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"io"
"net/http"
_ "net/http/pprof"
"os"
Expand Down Expand Up @@ -64,6 +65,7 @@ var (
archiveTimeout = cli.Flag("archive-timeout", "Maximum time to spend extracting an archive.").Duration()
includeDetectors = cli.Flag("include-detectors", "Comma separated list of detector types to include. Protobuf name or IDs may be used, as well as ranges.").Default("all").String()
excludeDetectors = cli.Flag("exclude-detectors", "Comma separated list of detector types to exclude. Protobuf name or IDs may be used, as well as ranges. IDs defined here take precedence over the include list.").String()
jobReportFile = cli.Flag("output-report", "Write a scan report to the provided path.").Hidden().OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)

gitScan = cli.Command("git", "Find credentials in git repositories.")
gitScanURI = gitScan.Arg("uri", "Git repository URL. https://, file://, or ssh:// schema expected.").Required().String()
Expand Down Expand Up @@ -398,6 +400,10 @@ func run(state overseer.State) {
fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n")
}

var jobReportWriter io.WriteCloser
if *jobReportFile != nil {
jobReportWriter = *jobReportFile
}
e, err := engine.Start(ctx,
engine.WithConcurrency(uint8(*concurrency)),
engine.WithDecoders(decoders.DefaultDecoders()...),
Expand All @@ -413,6 +419,7 @@ func run(state overseer.State) {
engine.WithPrinter(printer),
engine.WithFilterEntropy(*filterEntropy),
engine.WithVerificationOverlap(*allowVerificationOverlap),
engine.WithJobReportWriter(jobReportWriter),
)
if err != nil {
logFatal(err, "error initializing engine")
Expand Down
16 changes: 16 additions & 0 deletions pkg/common/export_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package common

// ExportError is an implementation of error whose underlying type is string,
// which lets encoding/json marshal it as a JSON string. It must be a public
// exported type for this reason.
type ExportError string

// Error implements the error interface by returning the stored message.
func (e ExportError) Error() string { return string(e) }

// ExportErrors converts each non-nil error in errs into an ExportError that
// carries the same message. The result is typed []error, but the dynamic type
// of every element is ExportError.
//
// Nil entries are skipped, because calling Error on a nil error would panic.
// The returned slice is always non-nil, even when errs is empty.
func ExportErrors(errs ...error) []error {
	output := make([]error, 0, len(errs))
	for _, err := range errs {
		if err == nil {
			continue
		}
		output = append(output, ExportError(err.Error()))
	}
	return output
}
64 changes: 54 additions & 10 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package engine

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"runtime"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -58,9 +60,10 @@ type Printer interface {

type Engine struct {
// CLI flags.
concurrency uint8
decoders []decoders.Decoder
detectors []detectors.Detector
concurrency uint8
decoders []decoders.Decoder
detectors []detectors.Detector
jobReportWriter io.WriteCloser
// filterUnverified is used to reduce the number of unverified results.
// If there are multiple unverified results for the same chunk for the same detector,
// only the first one will be kept.
Expand Down Expand Up @@ -119,6 +122,12 @@ func (r *verificationOverlapTracker) increment() {
// Option is used to configure the engine during initialization using functional options.
type Option func(*Engine)

// WithJobReportWriter returns an Option that stores w as the engine's job
// report writer. Pass a nil interface value to leave it unset; a typed nil
// pointer wrapped in the interface is still non-nil and is stored as is.
func WithJobReportWriter(w io.WriteCloser) Option {
	var setWriter Option = func(eng *Engine) {
		eng.jobReportWriter = w
	}
	return setWriter
}

func WithConcurrency(concurrency uint8) Option {
return func(e *Engine) {
e.concurrency = concurrency
Expand Down Expand Up @@ -317,6 +326,7 @@ func Start(ctx context.Context, options ...Option) (*Engine, error) {
if err := e.initialize(ctx, options...); err != nil {
return nil, err
}
e.initSourceManager(ctx)
e.setDefaults(ctx)
e.sanityChecks(ctx)
e.startWorkers(ctx)
Expand Down Expand Up @@ -373,6 +383,47 @@ func (e *Engine) initialize(ctx context.Context, options ...Option) error {
return nil
}

// initSourceManager builds the engine's source manager. When a job report
// writer is configured it also installs a report hook and starts a goroutine
// that writes each finished job's metrics to the writer as one JSON object
// per line.
//
// The goroutine is registered on e.wgDetectorWorkers, so waiting on that
// group also waits for the report to be written and the writer closed.
func (e *Engine) initSourceManager(ctx context.Context) {
	opts := []func(*sources.SourceManager){
		sources.WithConcurrentSources(int(e.concurrency)),
		sources.WithConcurrentUnits(int(e.concurrency)),
		sources.WithSourceUnits(),
		sources.WithBufferedOutput(defaultChannelBuffer),
	}
	if e.jobReportWriter != nil {
		unitHook, finishedMetrics := sources.NewUnitHook(ctx)
		opts = append(opts, sources.WithReportHook(unitHook))
		e.wgDetectorWorkers.Add(1)
		go func() {
			defer e.wgDetectorWorkers.Done()
			defer func() {
				// A failed Close can mean buffered data never reached its
				// destination, so report the error instead of claiming the
				// report was written.
				if err := e.jobReportWriter.Close(); err != nil {
					ctx.Logger().Error(err, "error closing job report writer")
					return
				}
				// Add a bit of extra information if it's an *os.File.
				if namer, ok := e.jobReportWriter.(interface{ Name() string }); ok {
					ctx.Logger().Info("report written", "path", namer.Name())
				} else {
					ctx.Logger().Info("report written")
				}
			}()
			// NOTE(review): presumably finishedMetrics is a channel that the
			// unit hook closes once all jobs are done — confirm in sources.
			for metrics := range finishedMetrics {
				// Replace the errors with JSON-marshallable equivalents.
				metrics.Errors = common.ExportErrors(metrics.Errors...)
				details, err := json.Marshal(map[string]any{
					"version": 1,
					"data":    metrics,
				})
				if err != nil {
					ctx.Logger().Error(err, "error marshalling job details")
					continue
				}
				// Write the record and its newline in a single call.
				if _, err := e.jobReportWriter.Write(append(details, '\n')); err != nil {
					ctx.Logger().Error(err, "error writing to file")
				}
			}
		}()
	}
	e.sourceManager = sources.NewManager(opts...)
}

// setDefaults ensures that if specific engine properties aren't provided,
// they're set to reasonable default values. It makes the engine robust to
// incomplete configuration.
Expand All @@ -384,13 +435,6 @@ func (e *Engine) setDefaults(ctx context.Context) {
}
ctx.Logger().V(3).Info("engine started", "workers", e.concurrency)

e.sourceManager = sources.NewManager(
sources.WithConcurrentSources(int(e.concurrency)),
sources.WithConcurrentUnits(int(e.concurrency)),
sources.WithSourceUnits(),
sources.WithBufferedOutput(defaultChannelBuffer),
)

// Default decoders handle common encoding formats.
if len(e.decoders) == 0 {
e.decoders = decoders.DefaultDecoders()
Expand Down

0 comments on commit 4a4bb89

Please sign in to comment.