From f3630da1e0b56063f243259b70b2572c98a5ea96 Mon Sep 17 00:00:00 2001 From: Dustin Decker Date: Thu, 26 Sep 2024 10:17:47 -0700 Subject: [PATCH] Improve process cleanup (#3339) * ensures that cmd.Wait() is always called, even if there's a panic in the FromReader function or if stdOut.Close() returns an error * close stdout and ensure wait is called when handling binaries * process cleanup improvements * lint --- pkg/gitparse/gitparse.go | 8 +++--- pkg/gitparse/gitparse_test.go | 10 ++++++++ pkg/process/zombies.go | 48 +++++++++++++++++++++++++++++++++++ pkg/sources/git/git.go | 26 ++++++++++++++----- pkg/sources/git/git_test.go | 9 +++++++ 5 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 pkg/process/zombies.go diff --git a/pkg/gitparse/gitparse.go b/pkg/gitparse/gitparse.go index 7a827808fcf8..e05f04ec5068 100644 --- a/pkg/gitparse/gitparse.go +++ b/pkg/gitparse/gitparse.go @@ -314,13 +314,15 @@ func (c *Parser) executeCommand(ctx context.Context, cmd *exec.Cmd, isStaged boo }() go func() { + defer func() { + if err := cmd.Wait(); err != nil { + ctx.Logger().V(2).Info("Error waiting for git command to complete.", "error", err) + } + }() c.FromReader(ctx, stdOut, diffChan, isStaged) if err := stdOut.Close(); err != nil { ctx.Logger().V(2).Info("Error closing git stdout pipe.", "error", err) } - if err := cmd.Wait(); err != nil { - ctx.Logger().V(2).Info("Error waiting for git command to complete.", "error", err) - } }() return diffChan, nil diff --git a/pkg/gitparse/gitparse_test.go b/pkg/gitparse/gitparse_test.go index f5ca0d57961e..686637b1c095 100644 --- a/pkg/gitparse/gitparse_test.go +++ b/pkg/gitparse/gitparse_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/trufflesecurity/trufflehog/v3/pkg/context" + "github.com/trufflesecurity/trufflehog/v3/pkg/process" bufferwriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffer_writer" bufferedfilewriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffered_file_writer" ) @@ -791,6 +792,8 @@ func (d1 *Diff) Equal(ctx context.Context, d2 *Diff) bool { func TestCommitParsing(t *testing.T) { expected := expectedDiffs() + beforeProcesses := process.GetGitProcessList() + r := bytes.NewReader([]byte(commitLog)) diffChan := make(chan *Diff) parser := NewParser() @@ -809,6 +812,13 @@ func TestCommitParsing(t *testing.T) { } i++ } + + afterProcesses := process.GetGitProcessList() + zombies := process.DetectGitZombies(beforeProcesses, afterProcesses) + + if len(zombies) > 0 { + t.Errorf("Detected %d zombie git processes: %v", len(zombies), zombies) + } } func newBufferedFileWriterWithContent(content []byte) *bufferedfilewriter.BufferedFileWriter { diff --git a/pkg/process/zombies.go b/pkg/process/zombies.go new file mode 100644 index 000000000000..964c9299ffee --- /dev/null +++ b/pkg/process/zombies.go @@ -0,0 +1,48 @@ +package process + +import ( + "os/exec" + "runtime" + "strings" +) + +func GetGitProcessList() []string { + var cmd *exec.Cmd + if runtime.GOOS == "darwin" { + cmd = exec.Command("ps", "-eo", "pid,state,command") + } else { + cmd = exec.Command("ps", "-eo", "pid,stat,cmd") + } + + output, err := cmd.Output() + if err != nil { + return nil + } + + lines := strings.Split(string(output), "\n") + var gitProcesses []string + for _, line := range lines { + if strings.Contains(line, "git") { + gitProcesses = append(gitProcesses, line) + } + } + return gitProcesses +} + +func DetectGitZombies(before, after []string) []string { + beforeMap := make(map[string]bool) + for _, process := range before { + beforeMap[process] = true + } + + var zombies []string + for _, process := range after { + if !beforeMap[process] { + fields := strings.Fields(process) + if len(fields) >= 2 && (fields[1] == "Z" || strings.HasPrefix(fields[1], "Z")) { + zombies = append(zombies, process) + } + } + } + return zombies +} diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index 0a83664d8305..5843f49b83a2 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -1241,32 +1241,44 @@ func (s *Git) handleBinary(ctx context.Context, gitDir string, reporter sources. } cmd := exec.Command("git", "-C", gitDir, "cat-file", "blob", commitHash.String()+":"+path) - stdout, err := s.executeCatFileCmd(cmd) + stdout, catCmd, err := s.executeCatFileCmd(cmd) if err != nil { return err } + // Wait must be called after closing the pipe (defer is a stack, so first defer is executed last) + defer func() { + _ = catCmd.Wait() + }() + defer stdout.Close() + + err = handlers.HandleFile(ctx, stdout, chunkSkel, reporter, handlers.WithSkipArchives(s.skipArchives)) - if err = handlers.HandleFile(ctx, stdout, chunkSkel, reporter, handlers.WithSkipArchives(s.skipArchives)); err != nil { + // Always call Wait() to ensure the process is properly cleaned up + waitErr := cmd.Wait() + + // If there was an error in HandleFile, return that error + if err != nil { return err } - return cmd.Wait() + // If Wait() resulted in an error, return that error + return waitErr } -func (s *Git) executeCatFileCmd(cmd *exec.Cmd) (io.ReadCloser, error) { +func (s *Git) executeCatFileCmd(cmd *exec.Cmd) (io.ReadCloser, *exec.Cmd, error) { var stderr bytes.Buffer cmd.Stderr = &stderr stdout, err := cmd.StdoutPipe() if err != nil { - return nil, fmt.Errorf("error running git cat-file: %w\n%s", err, stderr.Bytes()) + return nil, nil, fmt.Errorf("error running git cat-file: %w\n%s", err, stderr.Bytes()) } if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("error starting git cat-file: %w\n%s", err, stderr.Bytes()) + return nil, nil, fmt.Errorf("error starting git cat-file: %w\n%s", err, stderr.Bytes()) } - return stdout, nil + return stdout, cmd, nil } func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error { diff --git a/pkg/sources/git/git_test.go b/pkg/sources/git/git_test.go index 4b41e61fff6d..7b0991fd0188 100644 --- a/pkg/sources/git/git_test.go +++ b/pkg/sources/git/git_test.go @@ -15,6 +15,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/process" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest" ) @@ -235,6 +236,8 @@ func TestSource_Chunks_Integration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { s := Source{} + beforeProcesses := process.GetGitProcessList() + conn, err := anypb.New(tt.init.connection) if err != nil { t.Fatal(err) @@ -280,6 +283,12 @@ func TestSource_Chunks_Integration(t *testing.T) { } } + + afterProcesses := process.GetGitProcessList() + zombies := process.DetectGitZombies(beforeProcesses, afterProcesses) + if len(zombies) > 0 { + t.Errorf("Git zombies detected: %v", zombies) + } }) } }