Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve --err-first-hit handling #596

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package main

import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
Expand Down Expand Up @@ -80,6 +81,16 @@ var riskMap = map[string]int{
"critical": 4,
}

func showError(err error) {
emoji := "πŸ’£"
if errors.Is(err, action.ErrMatchedCondition) {
emoji = "πŸ‘‹"
err = errors.Unwrap(err)
}

fmt.Fprintf(os.Stderr, "%s %s\n", emoji, err.Error())
}

//nolint:cyclop // ignore complexity of 40
func main() {
returnCode := ExitOK
Expand Down Expand Up @@ -398,7 +409,7 @@ func main() {
ps, err := action.ActiveProcesses(ctx)
if err != nil {
returnCode = ExitActionFailed
return fmt.Errorf("process paths: %w", err)
return err
}
for _, p := range ps {
// in the future, we'll also want to attach process info directly
Expand All @@ -409,7 +420,7 @@ func main() {
res, err = action.Scan(ctx, mc)
if err != nil {
returnCode = ExitActionFailed
return fmt.Errorf("scan: %w", err)
return err
}

err = renderer.Full(ctx, res)
Expand Down Expand Up @@ -530,7 +541,13 @@ func main() {
}

if err := app.Run(os.Args); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
returnCode = ExitActionFailed
if returnCode != 0 {
returnCode = ExitActionFailed
}
if errors.Is(err, action.ErrMatchedCondition) {
returnCode = ExitOK
}

showError(err)
}
}
44 changes: 24 additions & 20 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package action

import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
Expand All @@ -31,7 +32,8 @@ var (
// compiledRuleCache are a cache of previously compiled rules.
compiledRuleCache *yara.Rules
// compileOnce ensures that we compile rules only once even across threads.
compileOnce sync.Once
compileOnce sync.Once
ErrMatchedCondition = errors.New("matched requested condition")
)

// findFilesRecursively returns a list of files found recursively within a path.
Expand Down Expand Up @@ -233,7 +235,6 @@ func cachedRules(ctx context.Context, fss []fs.FS) (*yara.Rules, error) {
//nolint:gocognit,cyclop // ignoring complexity of 101,38
func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) {
logger := clog.FromContext(ctx)
logger.Debug("recursive scan", slog.Any("config", c))
r := &malcontent.Report{
Files: orderedmap.New[string, *malcontent.FileReport](),
}
Expand All @@ -243,11 +244,12 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report

var scanPathFindings sync.Map

var waitErr error

for _, scanPath := range c.ScanPaths {
if c.Renderer != nil {
c.Renderer.Scanning(ctx, scanPath)
}
logger.Debug("recursive scan", slog.Any("scanPath", scanPath))
imageURI := ""
ociExtractPath := ""
var err error
Expand Down Expand Up @@ -323,18 +325,19 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
fr, err := processFile(ctx, c, c.RuleFS, path, scanPath, trimPath, logger)
if err != nil {
scanPathFindings.Store(path, &malcontent.FileReport{})
return err
return fmt.Errorf("process: %w", err)
}
if fr != nil {
scanPathFindings.Store(path, fr)
if !c.OCI {
var frMap sync.Map
frMap.Store(path, fr)
if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil {
logger.Debugf("match short circuit: %s", err)
scanPathFindings.Store(path, fr)
return err
}
if fr == nil {
return nil
}

scanPathFindings.Store(path, fr)
if !c.OCI {
var frMap sync.Map
frMap.Store(path, fr)
if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil {
scanPathFindings.Store(path, fr)
return fmt.Errorf("%q: %w", path, ErrMatchedCondition)
}
}
return nil
Expand All @@ -351,8 +354,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
}

if err := g.Wait(); err != nil {
logger.Errorf("error with processing %v\n", err)
return nil, err
waitErr = err
}

var pathKeys []string
Expand Down Expand Up @@ -396,6 +398,11 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
}
}
}

// short-circuit out
if waitErr != nil {
return r, waitErr
}
} // loop: next scan path
return r, nil
}
Expand Down Expand Up @@ -460,9 +467,6 @@ func processFile(ctx context.Context, c malcontent.Config, ruleFS []fs.FS, path
func Scan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) {
r, err := recursiveScan(ctx, c)
if err != nil {
if strings.Contains(err.Error(), "no matching capabilities") {
return r, nil
}
return r, err
}
for files := r.Files.Oldest(); files != nil; files = files.Next() {
Expand All @@ -473,7 +477,7 @@ func Scan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error)
if c.Stats {
err = render.Statistics(r)
if err != nil {
return r, err
return r, fmt.Errorf("stats: %w", err)
}
}
return r, nil
Expand Down