Skip to content

Commit

Permalink
Improve process cleanup (#3339)
Browse files Browse the repository at this point in the history
* ensures that cmd.Wait() is always called, even if there's a panic in the FromReader function or if stdOut.Close() returns an error

* close stdout and ensure wait is called when handling binaries

* process cleanup improvements

* lint
  • Loading branch information
dustin-decker authored Sep 26, 2024
1 parent 6d022e7 commit f3630da
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 10 deletions.
8 changes: 5 additions & 3 deletions pkg/gitparse/gitparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,15 @@ func (c *Parser) executeCommand(ctx context.Context, cmd *exec.Cmd, isStaged boo
}()

go func() {
defer func() {
if err := cmd.Wait(); err != nil {
ctx.Logger().V(2).Info("Error waiting for git command to complete.", "error", err)
}
}()
c.FromReader(ctx, stdOut, diffChan, isStaged)
if err := stdOut.Close(); err != nil {
ctx.Logger().V(2).Info("Error closing git stdout pipe.", "error", err)
}
if err := cmd.Wait(); err != nil {
ctx.Logger().V(2).Info("Error waiting for git command to complete.", "error", err)
}
}()

return diffChan, nil
Expand Down
10 changes: 10 additions & 0 deletions pkg/gitparse/gitparse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/process"
bufferwriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffer_writer"
bufferedfilewriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffered_file_writer"
)
Expand Down Expand Up @@ -791,6 +792,8 @@ func (d1 *Diff) Equal(ctx context.Context, d2 *Diff) bool {
func TestCommitParsing(t *testing.T) {
expected := expectedDiffs()

beforeProcesses := process.GetGitProcessList()

r := bytes.NewReader([]byte(commitLog))
diffChan := make(chan *Diff)
parser := NewParser()
Expand All @@ -809,6 +812,13 @@ func TestCommitParsing(t *testing.T) {
}
i++
}

afterProcesses := process.GetGitProcessList()
zombies := process.DetectGitZombies(beforeProcesses, afterProcesses)

if len(zombies) > 0 {
t.Errorf("Detected %d zombie git processes: %v", len(zombies), zombies)
}
}

func newBufferedFileWriterWithContent(content []byte) *bufferedfilewriter.BufferedFileWriter {
Expand Down
48 changes: 48 additions & 0 deletions pkg/process/zombies.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package process

import (
"os/exec"
"runtime"
"strings"
)

func GetGitProcessList() []string {
var cmd *exec.Cmd
if runtime.GOOS == "darwin" {
cmd = exec.Command("ps", "-eo", "pid,state,command")
} else {
cmd = exec.Command("ps", "-eo", "pid,stat,cmd")
}

output, err := cmd.Output()
if err != nil {
return nil
}

lines := strings.Split(string(output), "\n")
var gitProcesses []string
for _, line := range lines {
if strings.Contains(line, "git") {
gitProcesses = append(gitProcesses, line)
}
}
return gitProcesses
}

func DetectGitZombies(before, after []string) []string {
beforeMap := make(map[string]bool)
for _, process := range before {
beforeMap[process] = true
}

var zombies []string
for _, process := range after {
if !beforeMap[process] {
fields := strings.Fields(process)
if len(fields) >= 2 && (fields[1] == "Z" || strings.HasPrefix(fields[1], "Z")) {
zombies = append(zombies, process)
}
}
}
return zombies
}
26 changes: 19 additions & 7 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,32 +1241,44 @@ func (s *Git) handleBinary(ctx context.Context, gitDir string, reporter sources.
}

cmd := exec.Command("git", "-C", gitDir, "cat-file", "blob", commitHash.String()+":"+path)
stdout, err := s.executeCatFileCmd(cmd)
stdout, catCmd, err := s.executeCatFileCmd(cmd)
if err != nil {
return err
}
// Wait must be called after closing the pipe (defer is a stack, so first defer is executed last)
defer func() {
_ = catCmd.Wait()
}()
defer stdout.Close()

err = handlers.HandleFile(ctx, stdout, chunkSkel, reporter, handlers.WithSkipArchives(s.skipArchives))

if err = handlers.HandleFile(ctx, stdout, chunkSkel, reporter, handlers.WithSkipArchives(s.skipArchives)); err != nil {
// Always call Wait() to ensure the process is properly cleaned up
waitErr := cmd.Wait()

// If there was an error in HandleFile, return that error
if err != nil {
return err
}

return cmd.Wait()
// If Wait() resulted in an error, return that error
return waitErr
}

func (s *Git) executeCatFileCmd(cmd *exec.Cmd) (io.ReadCloser, error) {
func (s *Git) executeCatFileCmd(cmd *exec.Cmd) (io.ReadCloser, *exec.Cmd, error) {
var stderr bytes.Buffer
cmd.Stderr = &stderr

stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("error running git cat-file: %w\n%s", err, stderr.Bytes())
return nil, nil, fmt.Errorf("error running git cat-file: %w\n%s", err, stderr.Bytes())
}

if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("error starting git cat-file: %w\n%s", err, stderr.Bytes())
return nil, nil, fmt.Errorf("error starting git cat-file: %w\n%s", err, stderr.Bytes())
}

return stdout, nil
return stdout, cmd, nil
}

func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
Expand Down
9 changes: 9 additions & 0 deletions pkg/sources/git/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/process"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest"
)
Expand Down Expand Up @@ -235,6 +236,8 @@ func TestSource_Chunks_Integration(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
s := Source{}

beforeProcesses := process.GetGitProcessList()

conn, err := anypb.New(tt.init.connection)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -280,6 +283,12 @@ func TestSource_Chunks_Integration(t *testing.T) {
}

}

afterProcesses := process.GetGitProcessList()
zombies := process.DetectGitZombies(beforeProcesses, afterProcesses)
if len(zombies) > 0 {
t.Errorf("Git zombies detected: %v", zombies)
}
})
}
}
Expand Down

0 comments on commit f3630da

Please sign in to comment.